package common import ( "encoding/base64" "errors" "fmt" "io" "net/http" "strconv" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/gin-gonic/gin" "github.com/samber/lo" ) type HasPrompt interface { GetPrompt() string } type HasImage interface { HasImage() bool } func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { switch channelType { case constant.ChannelTypeOpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) case constant.ChannelTypeAzure: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } return fullRequestURL } func GetAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { apiVersion = c.GetString("api_version") } return apiVersion } func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError { return &dto.TaskError{ Code: code, Message: err.Error(), StatusCode: statusCode, LocalError: localError, Error: err, } } func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) { info.Action = action c.Set("task_request", requestObj) } func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) { v, exists := c.Get("task_request") if !exists { return TaskSubmitReq{}, fmt.Errorf("request not found in context") } req, ok := v.(TaskSubmitReq) if !ok { return TaskSubmitReq{}, fmt.Errorf("invalid task request type") } return req, nil } func validatePrompt(prompt string) *dto.TaskError { if strings.TrimSpace(prompt) == "" { return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true) } return nil } func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) { var req TaskSubmitReq if _, err := c.MultipartForm(); err != nil { return req, err } formData := c.Request.PostForm req = TaskSubmitReq{ Prompt: formData.Get("prompt"), Model: formData.Get("model"), Mode: formData.Get("mode"), Image: formData.Get("image"), Size: formData.Get("size"), Metadata: make(map[string]interface{}), } if durationStr := formData.Get("seconds"); durationStr != "" { if duration, err := strconv.Atoi(durationStr); err == nil { req.Duration = duration } } if images := formData["images"]; len(images) > 0 { req.Images = images } for key, values := range formData { if len(values) > 0 && !isKnownTaskField(key) { if intVal, err := strconv.Atoi(values[0]); err == nil { req.Metadata[key] = intVal } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil { req.Metadata[key] = floatVal } else { req.Metadata[key] = values[0] } } } return req, nil } func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { var prompt string var model string var seconds int var size string var hasInputReference bool var req TaskSubmitReq if err := common.UnmarshalBodyReusable(c, &req); err != nil { return createTaskError(err, "invalid_json", http.StatusBadRequest, true) } prompt = req.Prompt model = req.Model size = req.Size seconds, _ = strconv.Atoi(req.Seconds) if seconds == 0 { seconds = req.Duration } if req.InputReference != "" { req.Images = []string{req.InputReference} } if strings.TrimSpace(req.Model) == "" { return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) } if req.HasImage() { hasInputReference = true } if taskErr := validatePrompt(prompt); taskErr != nil { return taskErr } action := constant.TaskActionTextGenerate if hasInputReference { action = constant.TaskActionGenerate } if strings.HasPrefix(model, "sora-2") { if size == "" { size = "720x1280" } if seconds <= 0 { seconds = 4 } if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } info.PriceData.OtherRatios = map[string]float64{ "seconds": float64(seconds), "size": 1, } if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { info.PriceData.OtherRatios["size"] = 1.666667 } } info.Action = action return nil } func isKnownTaskField(field string) bool { knownFields := map[string]bool{ "prompt": true, "model": true, "mode": true, "image": true, "images": true, "size": true, "duration": true, "input_reference": true, // Sora 特有字段 } return knownFields[field] } func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError { var err error contentType := c.GetHeader("Content-Type") var req TaskSubmitReq if strings.HasPrefix(contentType, "multipart/form-data") { req, err = validateMultipartTaskRequest(c, info, action) if err != nil { return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) } } else if err := common.UnmarshalBodyReusable(c, &req); err != nil { return createTaskError(err, "invalid_request", http.StatusBadRequest, true) } if taskErr := validatePrompt(req.Prompt); taskErr != nil { return taskErr } if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" { // 兼容单图上传 req.Images = []string{req.Image} } storeTaskRequest(c, info, action, req) return nil } func GetImagesBase64sFromForm(c *gin.Context) ([]*Base64Data, error) { return GetBase64sFromForm(c, "image") } func GetImageBase64sFromForm(c *gin.Context) (*Base64Data, error) { base64s, err := GetImagesBase64sFromForm(c) if err != nil { return nil, err } return base64s[0], nil } type Base64Data struct { MimeType string Data string } func (m Base64Data) String() string { return fmt.Sprintf("data:%s;base64,%s", m.MimeType, m.Data) } func GetBase64sFromForm(c *gin.Context, fieldName string) ([]*Base64Data, error) { mf := c.Request.MultipartForm if mf == nil { if _, err := c.MultipartForm(); err != nil { return nil, fmt.Errorf("failed to parse image edit form request: %w", err) } mf = c.Request.MultipartForm } imageFiles, exists := mf.File[fieldName] if !exists || len(imageFiles) == 0 { return nil, errors.New("field " + fieldName + " is not found or empty") } var imageBase64s []*Base64Data for _, file := range imageFiles { image, err := file.Open() if err != nil { return nil, errors.New("failed to open image file") } defer image.Close() imageData, err := io.ReadAll(image) if err != nil { return nil, errors.New("failed to read image file") } mimeType := http.DetectContentType(imageData) base64Data := base64.StdEncoding.EncodeToString(imageData) imageBase64s = append(imageBase64s, &Base64Data{ MimeType: mimeType, Data: base64Data, }) } return imageBase64s, nil }