| | package relay |
| |
|
| | import ( |
| | "bytes" |
| | "encoding/json" |
| | "fmt" |
| | "io" |
| | "log" |
| | "net/http" |
| | "strconv" |
| | "strings" |
| | "time" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/dto" |
| | "github.com/QuantumNous/new-api/model" |
| | relaycommon "github.com/QuantumNous/new-api/relay/common" |
| | relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| | "github.com/QuantumNous/new-api/relay/helper" |
| | "github.com/QuantumNous/new-api/service" |
| | "github.com/QuantumNous/new-api/setting" |
| | "github.com/QuantumNous/new-api/setting/system_setting" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | func RelayMidjourneyImage(c *gin.Context) { |
| | taskId := c.Param("id") |
| | midjourneyTask := model.GetByOnlyMJId(taskId) |
| | if midjourneyTask == nil { |
| | c.JSON(400, gin.H{ |
| | "error": "midjourney_task_not_found", |
| | }) |
| | return |
| | } |
| | var httpClient *http.Client |
| | if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil { |
| | proxy := channel.GetSetting().Proxy |
| | if proxy != "" { |
| | if httpClient, err = service.NewProxyHttpClient(proxy); err != nil { |
| | c.JSON(400, gin.H{ |
| | "error": "proxy_url_invalid", |
| | }) |
| | return |
| | } |
| | } |
| | } |
| | if httpClient == nil { |
| | httpClient = service.GetHttpClient() |
| | } |
| | resp, err := httpClient.Get(midjourneyTask.ImageUrl) |
| | if err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "error": "http_get_image_failed", |
| | }) |
| | return |
| | } |
| | defer resp.Body.Close() |
| | if resp.StatusCode != http.StatusOK { |
| | responseBody, _ := io.ReadAll(resp.Body) |
| | c.JSON(resp.StatusCode, gin.H{ |
| | "error": string(responseBody), |
| | }) |
| | return |
| | } |
| | |
| | contentType := resp.Header.Get("Content-Type") |
| | if contentType == "" { |
| | |
| | contentType = "image/jpeg" |
| | } |
| | |
| | c.Writer.Header().Set("Content-Type", contentType) |
| | |
| | _, err = io.Copy(c.Writer, resp.Body) |
| | if err != nil { |
| | log.Println("Failed to stream image:", err) |
| | } |
| | return |
| | } |
| |
|
| | func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { |
| | var midjRequest dto.MidjourneyDto |
| | err := common.UnmarshalBodyReusable(c, &midjRequest) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "bind_request_body_failed", |
| | Properties: nil, |
| | Result: "", |
| | } |
| | } |
| | midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) |
| | if midjourneyTask == nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "midjourney_task_not_found", |
| | Properties: nil, |
| | Result: "", |
| | } |
| | } |
| | midjourneyTask.Progress = midjRequest.Progress |
| | midjourneyTask.PromptEn = midjRequest.PromptEn |
| | midjourneyTask.State = midjRequest.State |
| | midjourneyTask.SubmitTime = midjRequest.SubmitTime |
| | midjourneyTask.StartTime = midjRequest.StartTime |
| | midjourneyTask.FinishTime = midjRequest.FinishTime |
| | midjourneyTask.ImageUrl = midjRequest.ImageUrl |
| | midjourneyTask.VideoUrl = midjRequest.VideoUrl |
| | videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls) |
| | midjourneyTask.VideoUrls = string(videoUrlsStr) |
| | midjourneyTask.Status = midjRequest.Status |
| | midjourneyTask.FailReason = midjRequest.FailReason |
| | err = midjourneyTask.Update() |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "update_midjourney_task_failed", |
| | } |
| | } |
| |
|
| | return nil |
| | } |
| |
|
| | func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { |
| | midjourneyTask.MjId = originTask.MjId |
| | midjourneyTask.Progress = originTask.Progress |
| | midjourneyTask.PromptEn = originTask.PromptEn |
| | midjourneyTask.State = originTask.State |
| | midjourneyTask.SubmitTime = originTask.SubmitTime |
| | midjourneyTask.StartTime = originTask.StartTime |
| | midjourneyTask.FinishTime = originTask.FinishTime |
| | midjourneyTask.ImageUrl = "" |
| | if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled { |
| | midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId |
| | if originTask.Status != "SUCCESS" { |
| | midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) |
| | } |
| | } else { |
| | midjourneyTask.ImageUrl = originTask.ImageUrl |
| | } |
| | if originTask.VideoUrl != "" { |
| | midjourneyTask.VideoUrl = originTask.VideoUrl |
| | } |
| | midjourneyTask.Status = originTask.Status |
| | midjourneyTask.FailReason = originTask.FailReason |
| | midjourneyTask.Action = originTask.Action |
| | midjourneyTask.Description = originTask.Description |
| | midjourneyTask.Prompt = originTask.Prompt |
| | if originTask.Buttons != "" { |
| | var buttons []dto.ActionButton |
| | err := json.Unmarshal([]byte(originTask.Buttons), &buttons) |
| | if err == nil { |
| | midjourneyTask.Buttons = buttons |
| | } |
| | } |
| | if originTask.VideoUrls != "" { |
| | var videoUrls []dto.ImgUrls |
| | err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls) |
| | if err == nil { |
| | midjourneyTask.VideoUrls = videoUrls |
| | } |
| | } |
| | if originTask.Properties != "" { |
| | var properties dto.Properties |
| | err := json.Unmarshal([]byte(originTask.Properties), &properties) |
| | if err == nil { |
| | midjourneyTask.Properties = &properties |
| | } |
| | } |
| | return |
| | } |
| |
|
| | func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { |
| | var swapFaceRequest dto.SwapFaceRequest |
| | err := common.UnmarshalBodyReusable(c, &swapFaceRequest) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") |
| | } |
| |
|
| | info.InitChannelMeta(c) |
| |
|
| | if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") |
| | } |
| | modelName := service.CoverActionToModelName(constant.MjActionSwapFace) |
| |
|
| | priceData := helper.ModelPriceHelperPerCall(c, info) |
| |
|
| | userQuota, err := model.GetUserQuota(info.UserId, false) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: err.Error(), |
| | } |
| | } |
| |
|
| | if userQuota-priceData.Quota < 0 { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "quota_not_enough", |
| | } |
| | } |
| | requestURL := getMjRequestPath(c.Request.URL.String()) |
| | baseURL := c.GetString("base_url") |
| | fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |
| | mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) |
| | if err != nil { |
| | return &mjResp.Response |
| | } |
| | defer func() { |
| | if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { |
| | err := service.PostConsumeQuota(info, priceData.Quota, 0, true) |
| | if err != nil { |
| | common.SysLog("error consuming token remain quota: " + err.Error()) |
| | } |
| |
|
| | tokenName := c.GetString("token_name") |
| | logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) |
| | other := service.GenerateMjOtherInfo(info, priceData) |
| | model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ |
| | ChannelId: info.ChannelId, |
| | ModelName: modelName, |
| | TokenName: tokenName, |
| | Quota: priceData.Quota, |
| | Content: logContent, |
| | TokenId: info.TokenId, |
| | Group: info.UsingGroup, |
| | Other: other, |
| | }) |
| | model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) |
| | model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) |
| | } |
| | }() |
| | midjResponse := &mjResp.Response |
| | midjourneyTask := &model.Midjourney{ |
| | UserId: info.UserId, |
| | Code: midjResponse.Code, |
| | Action: constant.MjActionSwapFace, |
| | MjId: midjResponse.Result, |
| | Prompt: "InsightFace", |
| | PromptEn: "", |
| | Description: midjResponse.Description, |
| | State: "", |
| | SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), |
| | StartTime: time.Now().UnixNano() / int64(time.Millisecond), |
| | FinishTime: 0, |
| | ImageUrl: "", |
| | Status: "", |
| | Progress: "0%", |
| | FailReason: "", |
| | ChannelId: c.GetInt("channel_id"), |
| | Quota: priceData.Quota, |
| | } |
| | err = midjourneyTask.Insert() |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed") |
| | } |
| | c.Writer.WriteHeader(mjResp.StatusCode) |
| | respBody, err := json.Marshal(midjResponse) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") |
| | } |
| | _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") |
| | } |
| | return nil |
| | } |
| |
|
| | func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { |
| | taskId := c.Param("id") |
| | userId := c.GetInt("id") |
| | originTask := model.GetByMJId(userId, taskId) |
| | if originTask == nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") |
| | } |
| | channel, err := model.GetChannelById(originTask.ChannelId, true) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") |
| | } |
| | if channel.Status != common.ChannelStatusEnabled { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") |
| | } |
| | c.Set("channel_id", originTask.ChannelId) |
| | c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
| |
|
| | requestURL := getMjRequestPath(c.Request.URL.String()) |
| | fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) |
| | midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) |
| | if err != nil { |
| | return &midjResponseWithStatus.Response |
| | } |
| | midjResponse := &midjResponseWithStatus.Response |
| | c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) |
| | respBody, err := json.Marshal(midjResponse) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") |
| | } |
| | service.IOCopyBytesGracefully(c, nil, respBody) |
| | return nil |
| | } |
| |
|
| | func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { |
| | userId := c.GetInt("id") |
| | var err error |
| | var respBody []byte |
| | switch relayMode { |
| | case relayconstant.RelayModeMidjourneyTaskFetch: |
| | taskId := c.Param("id") |
| | originTask := model.GetByMJId(userId, taskId) |
| | if originTask == nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "task_no_found", |
| | } |
| | } |
| | midjourneyTask := coverMidjourneyTaskDto(c, originTask) |
| | respBody, err = json.Marshal(midjourneyTask) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "unmarshal_response_body_failed", |
| | } |
| | } |
| | case relayconstant.RelayModeMidjourneyTaskFetchByCondition: |
| | var condition = struct { |
| | IDs []string `json:"ids"` |
| | }{} |
| | err = c.BindJSON(&condition) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "do_request_failed", |
| | } |
| | } |
| | var tasks []dto.MidjourneyDto |
| | if len(condition.IDs) != 0 { |
| | originTasks := model.GetByMJIds(userId, condition.IDs) |
| | for _, originTask := range originTasks { |
| | midjourneyTask := coverMidjourneyTaskDto(c, originTask) |
| | tasks = append(tasks, midjourneyTask) |
| | } |
| | } |
| | if tasks == nil { |
| | tasks = make([]dto.MidjourneyDto, 0) |
| | } |
| | respBody, err = json.Marshal(tasks) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "unmarshal_response_body_failed", |
| | } |
| | } |
| | } |
| |
|
| | c.Writer.Header().Set("Content-Type", "application/json") |
| |
|
| | _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "copy_response_body_failed", |
| | } |
| | } |
| | return nil |
| | } |
| |
|
| | func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { |
| | consumeQuota := true |
| | var midjRequest dto.MidjourneyRequest |
| | err := common.UnmarshalBodyReusable(c, &midjRequest) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") |
| | } |
| |
|
| | relayInfo.InitChannelMeta(c) |
| |
|
| | if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { |
| | mjErr := service.CoverPlusActionToNormalAction(&midjRequest) |
| | if mjErr != nil { |
| | return mjErr |
| | } |
| | relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange |
| | } |
| | if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { |
| | midjRequest.Action = constant.MjActionVideo |
| | } |
| |
|
| | if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { |
| | if midjRequest.Prompt == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") |
| | } |
| | midjRequest.Action = constant.MjActionImagine |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { |
| | midjRequest.Action = constant.MjActionDescribe |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { |
| | midjRequest.Action = constant.MjActionEdits |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { |
| | midjRequest.Action = constant.MjActionShorten |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { |
| | midjRequest.Action = constant.MjActionBlend |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { |
| | midjRequest.Action = constant.MjActionUpload |
| | } else if midjRequest.TaskId != "" { |
| | mjId := "" |
| | if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { |
| | if midjRequest.TaskId == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") |
| | } else if midjRequest.Action == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") |
| | } else if midjRequest.Index == 0 { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required") |
| | } |
| | |
| | mjId = midjRequest.TaskId |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { |
| | if midjRequest.Content == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") |
| | } |
| | params := service.ConvertSimpleChangeParams(midjRequest.Content) |
| | if params == nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") |
| | } |
| | mjId = params.TaskId |
| | midjRequest.Action = params.Action |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { |
| | |
| | |
| | |
| | mjId = midjRequest.TaskId |
| | midjRequest.Action = constant.MjActionModal |
| | } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { |
| | midjRequest.Action = constant.MjActionVideo |
| | if midjRequest.TaskId == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") |
| | } else if midjRequest.Action == "" { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") |
| | } |
| | mjId = midjRequest.TaskId |
| | } |
| |
|
| | originTask := model.GetByMJId(relayInfo.UserId, mjId) |
| | if originTask == nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") |
| | } else { |
| | if setting.MjActionCheckSuccessEnabled { |
| | if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") |
| | } |
| | } |
| | channel, err := model.GetChannelById(originTask.ChannelId, true) |
| | if err != nil { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") |
| | } |
| | if channel.Status != common.ChannelStatusEnabled { |
| | return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") |
| | } |
| | c.Set("base_url", channel.GetBaseURL()) |
| | c.Set("channel_id", originTask.ChannelId) |
| | c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
| | log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) |
| | } |
| | midjRequest.Prompt = originTask.Prompt |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | } |
| |
|
| | if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom { |
| | consumeQuota = false |
| | } |
| |
|
| | |
| | requestURL := getMjRequestPath(c.Request.URL.String()) |
| |
|
| | baseURL := c.GetString("base_url") |
| |
|
| | |
| |
|
| | fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |
| |
|
| | modelName := service.CoverActionToModelName(midjRequest.Action) |
| |
|
| | priceData := helper.ModelPriceHelperPerCall(c, relayInfo) |
| |
|
| | userQuota, err := model.GetUserQuota(relayInfo.UserId, false) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: err.Error(), |
| | } |
| | } |
| |
|
| | if consumeQuota && userQuota-priceData.Quota < 0 { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "quota_not_enough", |
| | } |
| | } |
| |
|
| | midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) |
| | if err != nil { |
| | return &midjResponseWithStatus.Response |
| | } |
| | midjResponse := &midjResponseWithStatus.Response |
| |
|
| | defer func() { |
| | if consumeQuota && midjResponseWithStatus.StatusCode == 200 { |
| | err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) |
| | if err != nil { |
| | common.SysLog("error consuming token remain quota: " + err.Error()) |
| | } |
| | tokenName := c.GetString("token_name") |
| | logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) |
| | other := service.GenerateMjOtherInfo(relayInfo, priceData) |
| | model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ |
| | ChannelId: relayInfo.ChannelId, |
| | ModelName: modelName, |
| | TokenName: tokenName, |
| | Quota: priceData.Quota, |
| | Content: logContent, |
| | TokenId: relayInfo.TokenId, |
| | Group: relayInfo.UsingGroup, |
| | Other: other, |
| | }) |
| | model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) |
| | model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) |
| | } |
| | }() |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | midjourneyTask := &model.Midjourney{ |
| | UserId: relayInfo.UserId, |
| | Code: midjResponse.Code, |
| | Action: midjRequest.Action, |
| | MjId: midjResponse.Result, |
| | Prompt: midjRequest.Prompt, |
| | PromptEn: "", |
| | Description: midjResponse.Description, |
| | State: "", |
| | SubmitTime: time.Now().UnixNano() / int64(time.Millisecond), |
| | StartTime: 0, |
| | FinishTime: 0, |
| | ImageUrl: "", |
| | Status: "", |
| | Progress: "0%", |
| | FailReason: "", |
| | ChannelId: c.GetInt("channel_id"), |
| | Quota: priceData.Quota, |
| | } |
| | if midjResponse.Code == 3 { |
| | |
| | channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) |
| | if err != nil { |
| | common.SysLog("get_channel_null: " + err.Error()) |
| | } |
| | if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { |
| | model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") |
| | } |
| | } |
| | if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { |
| | |
| | midjourneyTask.FailReason = midjResponse.Description |
| | consumeQuota = false |
| | } |
| |
|
| | if midjResponse.Code == 21 { |
| | |
| | properties, ok := midjResponse.Properties.(map[string]interface{}) |
| | if ok { |
| | imageUrl, ok1 := properties["imageUrl"].(string) |
| | status, ok2 := properties["status"].(string) |
| | if ok1 && ok2 { |
| | midjourneyTask.ImageUrl = imageUrl |
| | midjourneyTask.Status = status |
| | if status == "SUCCESS" { |
| | midjourneyTask.Progress = "100%" |
| | midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond) |
| | midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond) |
| | midjResponse.Code = 1 |
| | } |
| | } |
| | } |
| | |
| | if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom { |
| | newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) |
| | responseBody = []byte(newBody) |
| | } |
| | } |
| | if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" { |
| | midjourneyTask.Progress = "100%" |
| | midjourneyTask.Status = "SUCCESS" |
| | } |
| | err = midjourneyTask.Insert() |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "insert_midjourney_task_failed", |
| | } |
| | } |
| |
|
| | if midjResponse.Code == 22 { |
| | |
| | newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) |
| | responseBody = []byte(newBody) |
| | } |
| | |
| | bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) |
| |
|
| | |
| | |
| | |
| | c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) |
| |
|
| | _, err = io.Copy(c.Writer, bodyReader) |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "copy_response_body_failed", |
| | } |
| | } |
| | err = bodyReader.Close() |
| | if err != nil { |
| | return &dto.MidjourneyResponse{ |
| | Code: 4, |
| | Description: "close_response_body_failed", |
| | } |
| | } |
| | return nil |
| | } |
| |
|
// taskChangeParams groups the parameters of a task-change request: the target
// task ID, the action name and an index.
// NOTE(review): not referenced in the code shown here; confirm it is still used
// elsewhere in the package.
type taskChangeParams struct {
	ID, Action string
	Index      int
}
| |
|
// getMjRequestPath normalizes an incoming request URI for forwarding upstream.
//
// A URI containing "/mj-" (for example "/mj-fast/mj/submit/imagine") is
// rewritten to start at its first "/mj/" segment ("/mj/submit/imagine").
// Everything after that first "/mj/" is kept, including any later "/mj/" and
// the query string. URIs without "/mj-", or with "/mj-" but no "/mj/", are
// returned unchanged.
func getMjRequestPath(path string) string {
	if !strings.Contains(path, "/mj-") {
		return path
	}
	// SplitN with n=2 splits only at the first "/mj/", so nothing after a
	// second occurrence is lost (plain Split kept only the middle segment).
	parts := strings.SplitN(path, "/mj/", 2)
	if len(parts) < 2 {
		return path
	}
	return "/mj/" + parts[1]
}
| |
|