Spaces:
Build error
Build error
| package helper | |
| import ( | |
| "errors" | |
| "fmt" | |
| "math" | |
| "one-api/common" | |
| "one-api/dto" | |
| "one-api/logger" | |
| relayconstant "one-api/relay/constant" | |
| "one-api/types" | |
| "strings" | |
| "github.com/gin-gonic/gin" | |
| ) | |
| func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { | |
| relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) | |
| switch format { | |
| case types.RelayFormatOpenAI: | |
| request, err = GetAndValidateTextRequest(c, relayMode) | |
| case types.RelayFormatGemini: | |
| if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { | |
| request, err = GetAndValidateGeminiEmbeddingRequest(c) | |
| } else { | |
| request, err = GetAndValidateGeminiRequest(c) | |
| } | |
| case types.RelayFormatClaude: | |
| request, err = GetAndValidateClaudeRequest(c) | |
| case types.RelayFormatOpenAIResponses: | |
| request, err = GetAndValidateResponsesRequest(c) | |
| case types.RelayFormatOpenAIImage: | |
| request, err = GetAndValidOpenAIImageRequest(c, relayMode) | |
| case types.RelayFormatEmbedding: | |
| request, err = GetAndValidateEmbeddingRequest(c, relayMode) | |
| case types.RelayFormatRerank: | |
| request, err = GetAndValidateRerankRequest(c) | |
| case types.RelayFormatOpenAIAudio: | |
| request, err = GetAndValidAudioRequest(c, relayMode) | |
| case types.RelayFormatOpenAIRealtime: | |
| request = &dto.BaseRequest{} | |
| default: | |
| return nil, fmt.Errorf("unsupported relay format: %s", format) | |
| } | |
| return request, err | |
| } | |
| func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { | |
| audioRequest := &dto.AudioRequest{} | |
| err := common.UnmarshalBodyReusable(c, audioRequest) | |
| if err != nil { | |
| return nil, err | |
| } | |
| switch relayMode { | |
| case relayconstant.RelayModeAudioSpeech: | |
| if audioRequest.Model == "" { | |
| return nil, errors.New("model is required") | |
| } | |
| default: | |
| err = c.Request.ParseForm() | |
| if err != nil { | |
| return nil, err | |
| } | |
| formData := c.Request.PostForm | |
| if audioRequest.Model == "" { | |
| audioRequest.Model = formData.Get("model") | |
| } | |
| if audioRequest.Model == "" { | |
| return nil, errors.New("model is required") | |
| } | |
| audioRequest.ResponseFormat = formData.Get("response_format") | |
| if audioRequest.ResponseFormat == "" { | |
| audioRequest.ResponseFormat = "json" | |
| } | |
| } | |
| return audioRequest, nil | |
| } | |
| func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { | |
| var rerankRequest *dto.RerankRequest | |
| err := common.UnmarshalBodyReusable(c, &rerankRequest) | |
| if err != nil { | |
| logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) | |
| return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) | |
| } | |
| if rerankRequest.Query == "" { | |
| return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) | |
| } | |
| if len(rerankRequest.Documents) == 0 { | |
| return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) | |
| } | |
| return rerankRequest, nil | |
| } | |
| func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { | |
| var embeddingRequest *dto.EmbeddingRequest | |
| err := common.UnmarshalBodyReusable(c, &embeddingRequest) | |
| if err != nil { | |
| logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) | |
| return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) | |
| } | |
| if embeddingRequest.Input == nil { | |
| return nil, fmt.Errorf("input is empty") | |
| } | |
| if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { | |
| embeddingRequest.Model = "omni-moderation-latest" | |
| } | |
| if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { | |
| embeddingRequest.Model = c.Param("model") | |
| } | |
| return embeddingRequest, nil | |
| } | |
| func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { | |
| request := &dto.OpenAIResponsesRequest{} | |
| err := common.UnmarshalBodyReusable(c, request) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if request.Model == "" { | |
| return nil, errors.New("model is required") | |
| } | |
| if request.Input == nil { | |
| return nil, errors.New("input is required") | |
| } | |
| return request, nil | |
| } | |
| func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { | |
| imageRequest := &dto.ImageRequest{} | |
| switch relayMode { | |
| case relayconstant.RelayModeImagesEdits: | |
| if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { | |
| _, err := c.MultipartForm() | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to parse image edit form request: %w", err) | |
| } | |
| formData := c.Request.PostForm | |
| imageRequest.Prompt = formData.Get("prompt") | |
| imageRequest.Model = formData.Get("model") | |
| imageRequest.N = uint(common.String2Int(formData.Get("n"))) | |
| imageRequest.Quality = formData.Get("quality") | |
| imageRequest.Size = formData.Get("size") | |
| if imageRequest.Model == "gpt-image-1" { | |
| if imageRequest.Quality == "" { | |
| imageRequest.Quality = "standard" | |
| } | |
| } | |
| if imageRequest.N == 0 { | |
| imageRequest.N = 1 | |
| } | |
| watermark := formData.Has("watermark") | |
| if watermark { | |
| imageRequest.Watermark = &watermark | |
| } | |
| break | |
| } | |
| fallthrough | |
| default: | |
| err := common.UnmarshalBodyReusable(c, imageRequest) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if imageRequest.Model == "" { | |
| //imageRequest.Model = "dall-e-3" | |
| return nil, errors.New("model is required") | |
| } | |
| if strings.Contains(imageRequest.Size, "×") { | |
| return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") | |
| } | |
| // Not "256x256", "512x512", or "1024x1024" | |
| if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { | |
| if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | |
| return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") | |
| } | |
| if imageRequest.Size == "" { | |
| imageRequest.Size = "1024x1024" | |
| } | |
| } else if imageRequest.Model == "dall-e-3" { | |
| if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { | |
| return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") | |
| } | |
| if imageRequest.Quality == "" { | |
| imageRequest.Quality = "standard" | |
| } | |
| if imageRequest.Size == "" { | |
| imageRequest.Size = "1024x1024" | |
| } | |
| } else if imageRequest.Model == "gpt-image-1" { | |
| if imageRequest.Quality == "" { | |
| imageRequest.Quality = "auto" | |
| } | |
| } | |
| //if imageRequest.Prompt == "" { | |
| // return nil, errors.New("prompt is required") | |
| //} | |
| if imageRequest.N == 0 { | |
| imageRequest.N = 1 | |
| } | |
| } | |
| return imageRequest, nil | |
| } | |
| func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { | |
| textRequest = &dto.ClaudeRequest{} | |
| err = c.ShouldBindJSON(textRequest) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | |
| return nil, errors.New("field messages is required") | |
| } | |
| if textRequest.Model == "" { | |
| return nil, errors.New("field model is required") | |
| } | |
| //if textRequest.Stream { | |
| // relayInfo.IsStream = true | |
| //} | |
| return textRequest, nil | |
| } | |
| func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { | |
| textRequest := &dto.GeneralOpenAIRequest{} | |
| err := common.UnmarshalBodyReusable(c, textRequest) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { | |
| textRequest.Model = "text-moderation-latest" | |
| } | |
| if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { | |
| textRequest.Model = c.Param("model") | |
| } | |
| if textRequest.MaxTokens > math.MaxInt32/2 { | |
| return nil, errors.New("max_tokens is invalid") | |
| } | |
| if textRequest.Model == "" { | |
| return nil, errors.New("model is required") | |
| } | |
| if textRequest.WebSearchOptions != nil { | |
| if textRequest.WebSearchOptions.SearchContextSize != "" { | |
| validSizes := map[string]bool{ | |
| "high": true, | |
| "medium": true, | |
| "low": true, | |
| } | |
| if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { | |
| return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") | |
| } | |
| } else { | |
| textRequest.WebSearchOptions.SearchContextSize = "medium" | |
| } | |
| } | |
| switch relayMode { | |
| case relayconstant.RelayModeCompletions: | |
| if textRequest.Prompt == "" { | |
| return nil, errors.New("field prompt is required") | |
| } | |
| case relayconstant.RelayModeChatCompletions: | |
| if len(textRequest.Messages) == 0 { | |
| return nil, errors.New("field messages is required") | |
| } | |
| case relayconstant.RelayModeEmbeddings: | |
| case relayconstant.RelayModeModerations: | |
| if textRequest.Input == nil || textRequest.Input == "" { | |
| return nil, errors.New("field input is required") | |
| } | |
| case relayconstant.RelayModeEdits: | |
| if textRequest.Instruction == "" { | |
| return nil, errors.New("field instruction is required") | |
| } | |
| } | |
| return textRequest, nil | |
| } | |
| func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { | |
| request := &dto.GeminiChatRequest{} | |
| err := common.UnmarshalBodyReusable(c, request) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if len(request.Contents) == 0 { | |
| return nil, errors.New("contents is required") | |
| } | |
| //if c.Query("alt") == "sse" { | |
| // relayInfo.IsStream = true | |
| //} | |
| return request, nil | |
| } | |
| func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) { | |
| request := &dto.GeminiEmbeddingRequest{} | |
| err := common.UnmarshalBodyReusable(c, request) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return request, nil | |
| } | |