| package service |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "io" |
| "log" |
| "net/http" |
| "strconv" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/setting" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| func CoverActionToModelName(mjAction string) string { |
| modelName := "mj_" + strings.ToLower(mjAction) |
| if mjAction == constant.MjActionSwapFace { |
| modelName = "swap_face" |
| } |
| return modelName |
| } |
|
|
| func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) { |
| action := "" |
| if relayMode == relayconstant.RelayModeMidjourneyAction { |
| |
| err := CoverPlusActionToNormalAction(midjRequest) |
| if err != nil { |
| return "", err, false |
| } |
| action = midjRequest.Action |
| } else { |
| switch relayMode { |
| case relayconstant.RelayModeMidjourneyImagine: |
| action = constant.MjActionImagine |
| case relayconstant.RelayModeMidjourneyVideo: |
| action = constant.MjActionVideo |
| case relayconstant.RelayModeMidjourneyEdits: |
| action = constant.MjActionEdits |
| case relayconstant.RelayModeMidjourneyDescribe: |
| action = constant.MjActionDescribe |
| case relayconstant.RelayModeMidjourneyBlend: |
| action = constant.MjActionBlend |
| case relayconstant.RelayModeMidjourneyShorten: |
| action = constant.MjActionShorten |
| case relayconstant.RelayModeMidjourneyChange: |
| action = midjRequest.Action |
| case relayconstant.RelayModeMidjourneyModal: |
| action = constant.MjActionModal |
| case relayconstant.RelayModeSwapFace: |
| action = constant.MjActionSwapFace |
| case relayconstant.RelayModeMidjourneyUpload: |
| action = constant.MjActionUpload |
| case relayconstant.RelayModeMidjourneySimpleChange: |
| params := ConvertSimpleChangeParams(midjRequest.Content) |
| if params == nil { |
| return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false |
| } |
| action = params.Action |
| case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: |
| return "", nil, true |
| default: |
| return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false |
| } |
| } |
| modelName := CoverActionToModelName(action) |
| return modelName, nil, true |
| } |
|
|
| func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { |
| |
| customId := midjRequest.CustomId |
| if customId == "" { |
| return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") |
| } |
| splits := strings.Split(customId, "::") |
| var action string |
| if splits[1] == "JOB" { |
| action = splits[2] |
| } else { |
| action = splits[1] |
| } |
|
|
| if action == "" { |
| return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") |
| } |
| if strings.Contains(action, "upsample") { |
| index, err := strconv.Atoi(splits[3]) |
| if err != nil { |
| return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") |
| } |
| midjRequest.Index = index |
| midjRequest.Action = constant.MjActionUpscale |
| } else if strings.Contains(action, "variation") { |
| midjRequest.Index = 1 |
| if action == "variation" { |
| index, err := strconv.Atoi(splits[3]) |
| if err != nil { |
| return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") |
| } |
| midjRequest.Index = index |
| midjRequest.Action = constant.MjActionVariation |
| } else if action == "low_variation" { |
| midjRequest.Action = constant.MjActionLowVariation |
| } else if action == "high_variation" { |
| midjRequest.Action = constant.MjActionHighVariation |
| } |
| } else if strings.Contains(action, "pan") { |
| midjRequest.Action = constant.MjActionPan |
| midjRequest.Index = 1 |
| } else if strings.Contains(action, "reroll") { |
| midjRequest.Action = constant.MjActionReRoll |
| midjRequest.Index = 1 |
| } else if action == "Outpaint" { |
| midjRequest.Action = constant.MjActionZoom |
| midjRequest.Index = 1 |
| } else if action == "CustomZoom" { |
| midjRequest.Action = constant.MjActionCustomZoom |
| midjRequest.Index = 1 |
| } else if action == "Inpaint" { |
| midjRequest.Action = constant.MjActionInPaint |
| midjRequest.Index = 1 |
| } else { |
| return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) |
| } |
| return nil |
| } |
|
|
| func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { |
| split := strings.Split(content, " ") |
| if len(split) != 2 { |
| return nil |
| } |
|
|
| action := strings.ToLower(split[1]) |
| changeParams := &dto.MidjourneyRequest{} |
| changeParams.TaskId = split[0] |
|
|
| if action[0] == 'u' { |
| changeParams.Action = "UPSCALE" |
| } else if action[0] == 'v' { |
| changeParams.Action = "VARIATION" |
| } else if action == "r" { |
| changeParams.Action = "REROLL" |
| return changeParams |
| } else { |
| return nil |
| } |
|
|
| index, err := strconv.Atoi(action[1:2]) |
| if err != nil || index < 1 || index > 4 { |
| return nil |
| } |
| changeParams.Index = index |
| return changeParams |
| } |
|
|
| func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { |
| var nullBytes []byte |
| |
| |
| |
| var mapResult map[string]interface{} |
| |
| if c.Request.Method != "GET" { |
| err := json.NewDecoder(c.Request.Body).Decode(&mapResult) |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err |
| } |
| if !setting.MjAccountFilterEnabled { |
| delete(mapResult, "accountFilter") |
| } |
| if !setting.MjNotifyEnabled { |
| delete(mapResult, "notifyHook") |
| } |
| |
| |
| } |
| if setting.MjModeClearEnabled { |
| if prompt, ok := mapResult["prompt"].(string); ok { |
| prompt = strings.Replace(prompt, "--fast", "", -1) |
| prompt = strings.Replace(prompt, "--relax", "", -1) |
| prompt = strings.Replace(prompt, "--turbo", "", -1) |
|
|
| mapResult["prompt"] = prompt |
| } |
| } |
| reqBody, err := json.Marshal(mapResult) |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err |
| } |
| req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody))) |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err |
| } |
| ctx, cancel := context.WithTimeout(context.Background(), timeout) |
| |
| req = req.WithContext(ctx) |
| req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
| req.Header.Set("Accept", c.Request.Header.Get("Accept")) |
| auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey) |
| if auth != "" { |
| auth = strings.TrimPrefix(auth, "Bearer ") |
| req.Header.Set("mj-api-secret", auth) |
| } |
| defer cancel() |
| resp, err := GetHttpClient().Do(req) |
| if err != nil { |
| common.SysLog("do request failed: " + err.Error()) |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err |
| } |
| statusCode := resp.StatusCode |
| |
| |
| |
| err = req.Body.Close() |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err |
| } |
| err = c.Request.Body.Close() |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err |
| } |
| var midjResponse dto.MidjourneyResponse |
| var midjourneyUploadsResponse dto.MidjourneyUploadResponse |
| responseBody, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err |
| } |
| CloseResponseBodyGracefully(resp) |
| respStr := string(responseBody) |
| log.Printf("respStr: %s", respStr) |
| if respStr == "" { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil |
| } else { |
| err = json.Unmarshal(responseBody, &midjResponse) |
| if err != nil { |
| err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse) |
| if err2 != nil { |
| return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err |
| } |
| } |
| } |
| |
| |
| |
| |
| return &dto.MidjourneyResponseWithStatusCode{ |
| StatusCode: statusCode, |
| Response: midjResponse, |
| }, responseBody, nil |
| } |
|
|