| | package relay |
| |
|
| | import ( |
| | "bytes" |
| | "fmt" |
| | "io" |
| | "net/http" |
| | "strings" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/dto" |
| | "github.com/QuantumNous/new-api/logger" |
| | "github.com/QuantumNous/new-api/relay/channel/gemini" |
| | relaycommon "github.com/QuantumNous/new-api/relay/common" |
| | "github.com/QuantumNous/new-api/relay/helper" |
| | "github.com/QuantumNous/new-api/service" |
| | "github.com/QuantumNous/new-api/setting/model_setting" |
| | "github.com/QuantumNous/new-api/types" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { |
| | if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { |
| | configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget |
| | if configBudget != nil && *configBudget == 0 { |
| | |
| | return true |
| | } |
| | } |
| | return false |
| | } |
| |
|
// trimModelThinking normalizes a model name that carries a thinking marker.
//
// A trailing "-nothinking" or "-thinking" is stripped. Otherwise, when the
// name contains "-thinking-" (e.g. "model-thinking-128"), everything from the
// first such marker onward is replaced by a bare "-thinking". Names without
// any of these markers are returned unchanged.
func trimModelThinking(modelName string) string {
	// "-nothinking" is tested first, matching the original precedence.
	for _, suffix := range []string{"-nothinking", "-thinking"} {
		if strings.HasSuffix(modelName, suffix) {
			return modelName[:len(modelName)-len(suffix)]
		}
	}

	if idx := strings.Index(modelName, "-thinking-"); idx >= 0 {
		return modelName[:idx] + "-thinking"
	}
	return modelName
}
| |
|
| | func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { |
| | info.InitChannelMeta(c) |
| |
|
| | geminiReq, ok := info.Request.(*dto.GeminiChatRequest) |
| | if !ok { |
| | return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | request, err := common.DeepCopy(geminiReq) |
| | if err != nil { |
| | return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | |
| | err = helper.ModelMappedHelper(c, info, request) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { |
| | if isNoThinkingRequest(request) { |
| | |
| | if !strings.Contains(info.OriginModelName, "-nothinking") { |
| | |
| | noThinkingModelName := info.OriginModelName + "-nothinking" |
| | containPrice := helper.ContainPriceOrRatio(noThinkingModelName) |
| | if containPrice { |
| | info.OriginModelName = noThinkingModelName |
| | info.UpstreamModelName = noThinkingModelName |
| | } |
| | } |
| | } |
| | if request.GenerationConfig.ThinkingConfig == nil { |
| | gemini.ThinkingAdaptor(request, info) |
| | } |
| | } |
| |
|
| | adaptor := GetAdaptor(info.ApiType) |
| | if adaptor == nil { |
| | return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | adaptor.Init(info) |
| |
|
| | if info.ChannelSetting.SystemPrompt != "" { |
| | if request.SystemInstructions == nil { |
| | request.SystemInstructions = &dto.GeminiChatContent{ |
| | Parts: []dto.GeminiPart{ |
| | {Text: info.ChannelSetting.SystemPrompt}, |
| | }, |
| | } |
| | } else if len(request.SystemInstructions.Parts) == 0 { |
| | request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}} |
| | } else if info.ChannelSetting.SystemPromptOverride { |
| | common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) |
| | merged := false |
| | for i := range request.SystemInstructions.Parts { |
| | if request.SystemInstructions.Parts[i].Text == "" { |
| | continue |
| | } |
| | request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text |
| | merged = true |
| | break |
| | } |
| | if !merged { |
| | request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...) |
| | } |
| | } |
| | } |
| |
|
| | |
| | if request.SystemInstructions != nil { |
| | hasContent := false |
| | for _, part := range request.SystemInstructions.Parts { |
| | if part.Text != "" { |
| | hasContent = true |
| | break |
| | } |
| | } |
| | if !hasContent { |
| | request.SystemInstructions = nil |
| | } |
| | } |
| |
|
| | var requestBody io.Reader |
| | if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { |
| | body, err := common.GetRequestBody(c) |
| | if err != nil { |
| | return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) |
| | } |
| | requestBody = bytes.NewReader(body) |
| | } else { |
| | |
| | convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) |
| | } |
| | jsonData, err := common.Marshal(convertedRequest) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | |
| | if len(info.ParamOverride) > 0 { |
| | jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) |
| | } |
| | } |
| |
|
| | logger.LogDebug(c, "Gemini request body: "+string(jsonData)) |
| |
|
| | requestBody = bytes.NewReader(jsonData) |
| | } |
| |
|
| | resp, err := adaptor.DoRequest(c, info, requestBody) |
| | if err != nil { |
| | logger.LogError(c, "Do gemini request failed: "+err.Error()) |
| | return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) |
| | } |
| |
|
| | statusCodeMappingStr := c.GetString("status_code_mapping") |
| |
|
| | var httpResp *http.Response |
| | if resp != nil { |
| | httpResp = resp.(*http.Response) |
| | info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") |
| | if httpResp.StatusCode != http.StatusOK { |
| | newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) |
| | |
| | service.ResetStatusCode(newAPIError, statusCodeMappingStr) |
| | return newAPIError |
| | } |
| | } |
| |
|
| | usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) |
| | if openaiErr != nil { |
| | service.ResetStatusCode(openaiErr, statusCodeMappingStr) |
| | return openaiErr |
| | } |
| |
|
| | postConsumeQuota(c, info, usage.(*dto.Usage), "") |
| | return nil |
| | } |
| |
|
| | func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { |
| | info.InitChannelMeta(c) |
| |
|
| | isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") |
| | info.IsGeminiBatchEmbedding = isBatch |
| |
|
| | var req dto.Request |
| | var err error |
| | var inputTexts []string |
| |
|
| | if isBatch { |
| | batchRequest := &dto.GeminiBatchEmbeddingRequest{} |
| | err = common.UnmarshalBodyReusable(c, batchRequest) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
| | } |
| | req = batchRequest |
| | for _, r := range batchRequest.Requests { |
| | for _, part := range r.Content.Parts { |
| | if part.Text != "" { |
| | inputTexts = append(inputTexts, part.Text) |
| | } |
| | } |
| | } |
| | } else { |
| | singleRequest := &dto.GeminiEmbeddingRequest{} |
| | err = common.UnmarshalBodyReusable(c, singleRequest) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
| | } |
| | req = singleRequest |
| | for _, part := range singleRequest.Content.Parts { |
| | if part.Text != "" { |
| | inputTexts = append(inputTexts, part.Text) |
| | } |
| | } |
| | } |
| |
|
| | err = helper.ModelMappedHelper(c, info, req) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | req.SetModelName("models/" + info.UpstreamModelName) |
| |
|
| | adaptor := GetAdaptor(info.ApiType) |
| | if adaptor == nil { |
| | return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) |
| | } |
| | adaptor.Init(info) |
| |
|
| | var requestBody io.Reader |
| | jsonData, err := common.Marshal(req) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) |
| | } |
| |
|
| | |
| | if len(info.ParamOverride) > 0 { |
| | reqMap := make(map[string]interface{}) |
| | _ = common.Unmarshal(jsonData, &reqMap) |
| | for key, value := range info.ParamOverride { |
| | reqMap[key] = value |
| | } |
| | jsonData, err = common.Marshal(reqMap) |
| | if err != nil { |
| | return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) |
| | } |
| | } |
| | logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData)) |
| | requestBody = bytes.NewReader(jsonData) |
| |
|
| | resp, err := adaptor.DoRequest(c, info, requestBody) |
| | if err != nil { |
| | logger.LogError(c, "Do gemini request failed: "+err.Error()) |
| | return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) |
| | } |
| |
|
| | statusCodeMappingStr := c.GetString("status_code_mapping") |
| | var httpResp *http.Response |
| | if resp != nil { |
| | httpResp = resp.(*http.Response) |
| | if httpResp.StatusCode != http.StatusOK { |
| | newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) |
| | service.ResetStatusCode(newAPIError, statusCodeMappingStr) |
| | return newAPIError |
| | } |
| | } |
| |
|
| | usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) |
| | if openaiErr != nil { |
| | service.ResetStatusCode(openaiErr, statusCodeMappingStr) |
| | return openaiErr |
| | } |
| |
|
| | postConsumeQuota(c, info, usage.(*dto.Usage), "") |
| | return nil |
| | } |
| |
|