| | package common |
| |
|
| | import ( |
| | "encoding/base64" |
| | "errors" |
| | "fmt" |
| | "io" |
| | "net/http" |
| | "strconv" |
| | "strings" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/dto" |
| |
|
| | "github.com/gin-gonic/gin" |
| | "github.com/samber/lo" |
| | ) |
| |
|
type (
	// HasPrompt is implemented by request types that expose a text prompt
	// through GetPrompt.
	HasPrompt interface {
		GetPrompt() string
	}

	// HasImage is implemented by request types that can report, through the
	// HasImage method, whether they include image input.
	HasImage interface {
		HasImage() bool
	}
)
| |
|
// GetFullRequestURL concatenates baseURL and requestURL.
//
// When baseURL starts with "https://gateway.ai.cloudflare.com", a leading path
// segment is stripped from requestURL before joining:
//   - "/v1" for constant.ChannelTypeOpenAI channels,
//   - "/openai/deployments" for constant.ChannelTypeAzure channels.
//
// Other channel types, and all non-Cloudflare base URLs, use requestURL unchanged.
// NOTE(review): presumably the gateway base URL already encodes these
// provider-specific prefixes — confirm against the gateway's URL scheme.
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	path := requestURL
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case constant.ChannelTypeOpenAI:
			path = strings.TrimPrefix(path, "/v1")
		case constant.ChannelTypeAzure:
			path = strings.TrimPrefix(path, "/openai/deployments")
		}
	}
	return baseURL + path
}
| |
|
| | func GetAPIVersion(c *gin.Context) string { |
| | query := c.Request.URL.Query() |
| | apiVersion := query.Get("api-version") |
| | if apiVersion == "" { |
| | apiVersion = c.GetString("api_version") |
| | } |
| | return apiVersion |
| | } |
| |
|
| | func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError { |
| | return &dto.TaskError{ |
| | Code: code, |
| | Message: err.Error(), |
| | StatusCode: statusCode, |
| | LocalError: localError, |
| | Error: err, |
| | } |
| | } |
| |
|
| | func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) { |
| | info.Action = action |
| | c.Set("task_request", requestObj) |
| | } |
| | func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) { |
| | v, exists := c.Get("task_request") |
| | if !exists { |
| | return TaskSubmitReq{}, fmt.Errorf("request not found in context") |
| | } |
| | req, ok := v.(TaskSubmitReq) |
| | if !ok { |
| | return TaskSubmitReq{}, fmt.Errorf("invalid task request type") |
| | } |
| | return req, nil |
| | } |
| |
|
// validatePrompt returns nil when prompt contains any non-whitespace text.
// Otherwise it returns a local 400 "invalid_request" TaskError saying the
// prompt is required.
func validatePrompt(prompt string) *dto.TaskError {
	if strings.TrimSpace(prompt) != "" {
		return nil
	}
	return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
| |
|
// validateMultipartTaskRequest builds a TaskSubmitReq from a multipart/form-data
// request body. The info and action parameters are accepted but not used here.
//
// Mapping from form fields (first value unless stated otherwise):
//   - prompt, model, mode, image, size are copied as strings;
//   - seconds is parsed into Duration; a non-integer value is silently ignored;
//   - images: every value is copied into Images (left nil when absent).
//
// Every other field not recognised by isKnownTaskField is copied into Metadata
// using its first value. The value is stored as an int when strconv.Atoi accepts
// it, otherwise as a float64 when strconv.ParseFloat accepts it, otherwise as
// the raw string.
// NOTE(review): "seconds" is not in isKnownTaskField's list, so it is stored
// in Metadata as well as in Duration — confirm that is intended.
//
// Only non-file form values are read; uploaded files are not inspected. On a
// form-parsing error, the zero TaskSubmitReq and that error are returned.
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
	if _, parseErr := c.MultipartForm(); parseErr != nil {
		return TaskSubmitReq{}, parseErr
	}

	form := c.Request.PostForm
	out := TaskSubmitReq{
		Prompt:   form.Get("prompt"),
		Model:    form.Get("model"),
		Mode:     form.Get("mode"),
		Image:    form.Get("image"),
		Size:     form.Get("size"),
		Metadata: map[string]interface{}{},
	}

	if rawSeconds := form.Get("seconds"); rawSeconds != "" {
		if secs, convErr := strconv.Atoi(rawSeconds); convErr == nil {
			out.Duration = secs
		}
	}

	if imageList := form["images"]; len(imageList) > 0 {
		out.Images = imageList
	}

	for field, vals := range form {
		if len(vals) == 0 || isKnownTaskField(field) {
			continue
		}
		first := vals[0]
		var coerced interface{} = first
		if asInt, intErr := strconv.Atoi(first); intErr == nil {
			coerced = asInt
		} else if asFloat, floatErr := strconv.ParseFloat(first, 64); floatErr == nil {
			coerced = asFloat
		}
		out.Metadata[field] = coerced
	}
	return out, nil
}
| |
|
// ValidateMultipartDirect decodes the request body into a TaskSubmitReq using
// common.UnmarshalBodyReusable, validates it, and records the task action on
// info.Action.
//
// Steps:
//   - The duration is taken from the string Seconds field. If Seconds is
//     missing, non-numeric or "0", Duration is used instead.
//   - A non-empty InputReference replaces req.Images before req.HasImage() is
//     consulted. When HasImage reports true, the action is
//     constant.TaskActionGenerate; otherwise it is
//     constant.TaskActionTextGenerate.
//   - The model field must be non-blank, and the prompt must pass
//     validatePrompt.
//   - For models whose name starts with "sora-2":
//     - size defaults to "720x1280";
//     - a non-positive duration defaults to 4;
//     - "sora-2" and "sora-2-pro" sizes are checked against their supported
//       resolutions;
//     - info.PriceData.OtherRatios is replaced with "seconds" and "size"
//       pricing ratios.
//
// It returns nil on success, or a local 400 TaskError for the first failure.
// NOTE(review): despite the name, whether multipart bodies are accepted depends
// on common.UnmarshalBodyReusable — confirm.
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var req TaskSubmitReq
	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
	}

	prompt := req.Prompt
	model := req.Model
	size := req.Size
	// Seconds arrives as a string. A parse failure deliberately yields 0, so
	// the numeric Duration field is used as the fallback below.
	seconds, _ := strconv.Atoi(req.Seconds)
	if seconds == 0 {
		seconds = req.Duration
	}
	if req.InputReference != "" {
		req.Images = []string{req.InputReference}
	}

	if strings.TrimSpace(model) == "" {
		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
	}

	hasInputReference := req.HasImage()

	if taskErr := validatePrompt(prompt); taskErr != nil {
		return taskErr
	}

	action := constant.TaskActionTextGenerate
	if hasInputReference {
		action = constant.TaskActionGenerate
	}

	if strings.HasPrefix(model, "sora-2") {
		if size == "" {
			size = "720x1280"
		}
		if seconds <= 0 {
			seconds = 4
		}

		// Resolutions accepted for each exact model name. Other "sora-2*"
		// model names are not size-checked.
		supportedSizes := map[string][]string{
			"sora-2":     {"720x1280", "1280x720"},
			"sora-2-pro": {"720x1280", "1280x720", "1792x1024", "1024x1792"},
		}
		if allowed, ok := supportedSizes[model]; ok && !lo.Contains(allowed, size) {
			// Name the actual model so sora-2-pro failures are not reported as sora-2.
			return createTaskError(fmt.Errorf("%s size is invalid", model), "invalid_size", http.StatusBadRequest, true)
		}

		info.PriceData.OtherRatios = map[string]float64{
			"seconds": float64(seconds),
			"size":    1,
		}
		// The larger 1792x1024 / 1024x1792 resolutions use a higher size multiplier.
		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
			info.PriceData.OtherRatios["size"] = 1.666667
		}
	}

	info.Action = action
	return nil
}
| |
|
// isKnownTaskField reports whether field is one of the request fields that has
// dedicated handling. validateMultipartTaskRequest uses it to keep these fields
// out of TaskSubmitReq.Metadata.
//
// The comparison is exact and case-sensitive. A switch is used instead of
// building a map on every call: the map allocated and filled 8 entries per
// lookup just to test membership.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration", "input_reference":
		return true
	default:
		return false
	}
}
| |
|
// ValidateBasicTaskRequest decodes a task submission, checks its prompt, and
// saves it with storeTaskRequest, which also sets info.Action to action.
//
// A Content-Type starting with "multipart/form-data" is decoded with
// validateMultipartTaskRequest. The prefix is matched case-insensitively,
// because media types are case-insensitive. Any other Content-Type goes through
// common.UnmarshalBodyReusable.
//
// If Images is empty and Image is non-blank, Image is promoted to a
// one-element Images slice before storing.
//
// It returns nil on success, or a local 400 TaskError when decoding or prompt
// validation fails.
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var req TaskSubmitReq
	contentType := strings.ToLower(c.GetHeader("Content-Type"))
	// Each branch scopes its own err, avoiding the previous outer `err` that
	// the else-if branch shadowed.
	if strings.HasPrefix(contentType, "multipart/form-data") {
		formReq, err := validateMultipartTaskRequest(c, info, action)
		if err != nil {
			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
		}
		req = formReq
	} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
	}

	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
		req.Images = []string{req.Image}
	}

	storeTaskRequest(c, info, action, req)
	return nil
}
| | func GetImagesBase64sFromForm(c *gin.Context) ([]*Base64Data, error) { |
| | return GetBase64sFromForm(c, "image") |
| | } |
| | func GetImageBase64sFromForm(c *gin.Context) (*Base64Data, error) { |
| | base64s, err := GetImagesBase64sFromForm(c) |
| | if err != nil { |
| | return nil, err |
| | } |
| | return base64s[0], nil |
| | } |
| |
|
| | type Base64Data struct { |
| | MimeType string |
| | Data string |
| | } |
| |
|
| | func (m Base64Data) String() string { |
| | return fmt.Sprintf("data:%s;base64,%s", m.MimeType, m.Data) |
| | } |
| | func GetBase64sFromForm(c *gin.Context, fieldName string) ([]*Base64Data, error) { |
| | mf := c.Request.MultipartForm |
| | if mf == nil { |
| | if _, err := c.MultipartForm(); err != nil { |
| | return nil, fmt.Errorf("failed to parse image edit form request: %w", err) |
| | } |
| | mf = c.Request.MultipartForm |
| | } |
| | imageFiles, exists := mf.File[fieldName] |
| | if !exists || len(imageFiles) == 0 { |
| | return nil, errors.New("field " + fieldName + " is not found or empty") |
| | } |
| | var imageBase64s []*Base64Data |
| | for _, file := range imageFiles { |
| | image, err := file.Open() |
| | if err != nil { |
| | return nil, errors.New("failed to open image file") |
| | } |
| | defer image.Close() |
| | imageData, err := io.ReadAll(image) |
| | if err != nil { |
| | return nil, errors.New("failed to read image file") |
| | } |
| | mimeType := http.DetectContentType(imageData) |
| | base64Data := base64.StdEncoding.EncodeToString(imageData) |
| | imageBase64s = append(imageBase64s, &Base64Data{ |
| | MimeType: mimeType, |
| | Data: base64Data, |
| | }) |
| | } |
| | return imageBase64s, nil |
| | } |
| |
|