|
|
package helper |
|
|
|
|
|
import ( |
|
|
"encoding/json" |
|
|
"errors" |
|
|
"fmt" |
|
|
"math" |
|
|
"strings" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/logger" |
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant" |
|
|
"github.com/QuantumNous/new-api/types" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { |
|
|
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) |
|
|
|
|
|
switch format { |
|
|
case types.RelayFormatOpenAI: |
|
|
request, err = GetAndValidateTextRequest(c, relayMode) |
|
|
case types.RelayFormatGemini: |
|
|
if strings.Contains(c.Request.URL.Path, ":embedContent") { |
|
|
request, err = GetAndValidateGeminiEmbeddingRequest(c) |
|
|
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { |
|
|
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c) |
|
|
} else { |
|
|
request, err = GetAndValidateGeminiRequest(c) |
|
|
} |
|
|
case types.RelayFormatClaude: |
|
|
request, err = GetAndValidateClaudeRequest(c) |
|
|
case types.RelayFormatOpenAIResponses: |
|
|
request, err = GetAndValidateResponsesRequest(c) |
|
|
|
|
|
case types.RelayFormatOpenAIImage: |
|
|
request, err = GetAndValidOpenAIImageRequest(c, relayMode) |
|
|
case types.RelayFormatEmbedding: |
|
|
request, err = GetAndValidateEmbeddingRequest(c, relayMode) |
|
|
case types.RelayFormatRerank: |
|
|
request, err = GetAndValidateRerankRequest(c) |
|
|
case types.RelayFormatOpenAIAudio: |
|
|
request, err = GetAndValidAudioRequest(c, relayMode) |
|
|
case types.RelayFormatOpenAIRealtime: |
|
|
request = &dto.BaseRequest{} |
|
|
default: |
|
|
return nil, fmt.Errorf("unsupported relay format: %s", format) |
|
|
} |
|
|
return request, err |
|
|
} |
|
|
|
|
|
func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { |
|
|
audioRequest := &dto.AudioRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, audioRequest) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
switch relayMode { |
|
|
case relayconstant.RelayModeAudioSpeech: |
|
|
if audioRequest.Model == "" { |
|
|
return nil, errors.New("model is required") |
|
|
} |
|
|
default: |
|
|
if audioRequest.Model == "" { |
|
|
return nil, errors.New("model is required") |
|
|
} |
|
|
if audioRequest.ResponseFormat == "" { |
|
|
audioRequest.ResponseFormat = "json" |
|
|
} |
|
|
} |
|
|
return audioRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { |
|
|
var rerankRequest *dto.RerankRequest |
|
|
err := common.UnmarshalBodyReusable(c, &rerankRequest) |
|
|
if err != nil { |
|
|
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) |
|
|
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
|
|
} |
|
|
|
|
|
if rerankRequest.Query == "" { |
|
|
return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
|
|
} |
|
|
if len(rerankRequest.Documents) == 0 { |
|
|
return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
|
|
} |
|
|
return rerankRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { |
|
|
var embeddingRequest *dto.EmbeddingRequest |
|
|
err := common.UnmarshalBodyReusable(c, &embeddingRequest) |
|
|
if err != nil { |
|
|
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) |
|
|
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) |
|
|
} |
|
|
|
|
|
if embeddingRequest.Input == nil { |
|
|
return nil, fmt.Errorf("input is empty") |
|
|
} |
|
|
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { |
|
|
embeddingRequest.Model = "omni-moderation-latest" |
|
|
} |
|
|
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { |
|
|
embeddingRequest.Model = c.Param("model") |
|
|
} |
|
|
return embeddingRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { |
|
|
request := &dto.OpenAIResponsesRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, request) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
if request.Model == "" { |
|
|
return nil, errors.New("model is required") |
|
|
} |
|
|
if request.Input == nil { |
|
|
return nil, errors.New("input is required") |
|
|
} |
|
|
return request, nil |
|
|
} |
|
|
|
|
|
func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { |
|
|
imageRequest := &dto.ImageRequest{} |
|
|
|
|
|
switch relayMode { |
|
|
case relayconstant.RelayModeImagesEdits: |
|
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { |
|
|
_, err := c.MultipartForm() |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to parse image edit form request: %w", err) |
|
|
} |
|
|
formData := c.Request.PostForm |
|
|
imageRequest.Prompt = formData.Get("prompt") |
|
|
imageRequest.Model = formData.Get("model") |
|
|
imageRequest.N = uint(common.String2Int(formData.Get("n"))) |
|
|
imageRequest.Quality = formData.Get("quality") |
|
|
imageRequest.Size = formData.Get("size") |
|
|
imageRequest.ResponseFormat = formData.Get("response_format") |
|
|
if imageValue := formData.Get("image"); imageValue != "" { |
|
|
imageRequest.Image, _ = json.Marshal(imageValue) |
|
|
} |
|
|
|
|
|
if imageRequest.Model == "gpt-image-1" { |
|
|
if imageRequest.Quality == "" { |
|
|
imageRequest.Quality = "standard" |
|
|
} |
|
|
} |
|
|
if imageRequest.N == 0 { |
|
|
imageRequest.N = 1 |
|
|
} |
|
|
|
|
|
hasWatermark := formData.Has("watermark") |
|
|
if hasWatermark { |
|
|
watermark := formData.Get("watermark") == "true" |
|
|
imageRequest.Watermark = &watermark |
|
|
} |
|
|
break |
|
|
} |
|
|
fallthrough |
|
|
default: |
|
|
err := common.UnmarshalBodyReusable(c, imageRequest) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
if imageRequest.Model == "" { |
|
|
|
|
|
return nil, errors.New("model is required") |
|
|
} |
|
|
|
|
|
if strings.Contains(imageRequest.Size, "×") { |
|
|
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") |
|
|
} |
|
|
|
|
|
|
|
|
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { |
|
|
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { |
|
|
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") |
|
|
} |
|
|
if imageRequest.Size == "" { |
|
|
imageRequest.Size = "1024x1024" |
|
|
} |
|
|
} else if imageRequest.Model == "dall-e-3" { |
|
|
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { |
|
|
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") |
|
|
} |
|
|
if imageRequest.Quality == "" { |
|
|
imageRequest.Quality = "standard" |
|
|
} |
|
|
if imageRequest.Size == "" { |
|
|
imageRequest.Size = "1024x1024" |
|
|
} |
|
|
} else if imageRequest.Model == "gpt-image-1" { |
|
|
if imageRequest.Quality == "" { |
|
|
imageRequest.Quality = "auto" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if imageRequest.N == 0 { |
|
|
imageRequest.N = 1 |
|
|
} |
|
|
} |
|
|
|
|
|
return imageRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { |
|
|
textRequest = &dto.ClaudeRequest{} |
|
|
err = c.ShouldBindJSON(textRequest) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
if textRequest.Messages == nil || len(textRequest.Messages) == 0 { |
|
|
return nil, errors.New("field messages is required") |
|
|
} |
|
|
if textRequest.Model == "" { |
|
|
return nil, errors.New("field model is required") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return textRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { |
|
|
textRequest := &dto.GeneralOpenAIRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, textRequest) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { |
|
|
textRequest.Model = "text-moderation-latest" |
|
|
} |
|
|
if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { |
|
|
textRequest.Model = c.Param("model") |
|
|
} |
|
|
|
|
|
if textRequest.MaxTokens > math.MaxInt32/2 { |
|
|
return nil, errors.New("max_tokens is invalid") |
|
|
} |
|
|
if textRequest.Model == "" { |
|
|
return nil, errors.New("model is required") |
|
|
} |
|
|
if textRequest.WebSearchOptions != nil { |
|
|
if textRequest.WebSearchOptions.SearchContextSize != "" { |
|
|
validSizes := map[string]bool{ |
|
|
"high": true, |
|
|
"medium": true, |
|
|
"low": true, |
|
|
} |
|
|
if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { |
|
|
return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") |
|
|
} |
|
|
} else { |
|
|
textRequest.WebSearchOptions.SearchContextSize = "medium" |
|
|
} |
|
|
} |
|
|
switch relayMode { |
|
|
case relayconstant.RelayModeCompletions: |
|
|
if textRequest.Prompt == "" { |
|
|
return nil, errors.New("field prompt is required") |
|
|
} |
|
|
case relayconstant.RelayModeChatCompletions: |
|
|
|
|
|
|
|
|
if len(textRequest.Messages) == 0 && textRequest.Prefix == nil && textRequest.Suffix == nil { |
|
|
return nil, errors.New("field messages is required") |
|
|
} |
|
|
case relayconstant.RelayModeEmbeddings: |
|
|
case relayconstant.RelayModeModerations: |
|
|
if textRequest.Input == nil || textRequest.Input == "" { |
|
|
return nil, errors.New("field input is required") |
|
|
} |
|
|
case relayconstant.RelayModeEdits: |
|
|
if textRequest.Instruction == "" { |
|
|
return nil, errors.New("field instruction is required") |
|
|
} |
|
|
} |
|
|
return textRequest, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { |
|
|
request := &dto.GeminiChatRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, request) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
if len(request.Contents) == 0 && len(request.Requests) == 0 { |
|
|
return nil, errors.New("contents is required") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return request, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) { |
|
|
request := &dto.GeminiEmbeddingRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, request) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
return request, nil |
|
|
} |
|
|
|
|
|
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) { |
|
|
request := &dto.GeminiBatchEmbeddingRequest{} |
|
|
err := common.UnmarshalBodyReusable(c, request) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
return request, nil |
|
|
} |
|
|
|