|
|
package relay |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"log" |
|
|
"net/http" |
|
|
"strconv" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/model" |
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common" |
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant" |
|
|
"github.com/QuantumNous/new-api/relay/helper" |
|
|
"github.com/QuantumNous/new-api/service" |
|
|
"github.com/QuantumNous/new-api/setting" |
|
|
"github.com/QuantumNous/new-api/setting/system_setting" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
func RelayMidjourneyImage(c *gin.Context) { |
|
|
taskId := c.Param("id") |
|
|
midjourneyTask := model.GetByOnlyMJId(taskId) |
|
|
if midjourneyTask == nil { |
|
|
c.JSON(400, gin.H{ |
|
|
"error": "midjourney_task_not_found", |
|
|
}) |
|
|
return |
|
|
} |
|
|
var httpClient *http.Client |
|
|
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil { |
|
|
proxy := channel.GetSetting().Proxy |
|
|
if proxy != "" { |
|
|
if httpClient, err = service.NewProxyHttpClient(proxy); err != nil { |
|
|
c.JSON(400, gin.H{ |
|
|
"error": "proxy_url_invalid", |
|
|
}) |
|
|
return |
|
|
} |
|
|
} |
|
|
} |
|
|
if httpClient == nil { |
|
|
httpClient = service.GetHttpClient() |
|
|
} |
|
|
resp, err := httpClient.Get(midjourneyTask.ImageUrl) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{ |
|
|
"error": "http_get_image_failed", |
|
|
}) |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
if resp.StatusCode != http.StatusOK { |
|
|
responseBody, _ := io.ReadAll(resp.Body) |
|
|
c.JSON(resp.StatusCode, gin.H{ |
|
|
"error": string(responseBody), |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
contentType := resp.Header.Get("Content-Type") |
|
|
if contentType == "" { |
|
|
|
|
|
contentType = "image/jpeg" |
|
|
} |
|
|
|
|
|
c.Writer.Header().Set("Content-Type", contentType) |
|
|
|
|
|
_, err = io.Copy(c.Writer, resp.Body) |
|
|
if err != nil { |
|
|
log.Println("Failed to stream image:", err) |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { |
|
|
var midjRequest dto.MidjourneyDto |
|
|
err := common.UnmarshalBodyReusable(c, &midjRequest) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "bind_request_body_failed", |
|
|
Properties: nil, |
|
|
Result: "", |
|
|
} |
|
|
} |
|
|
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) |
|
|
if midjourneyTask == nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "midjourney_task_not_found", |
|
|
Properties: nil, |
|
|
Result: "", |
|
|
} |
|
|
} |
|
|
midjourneyTask.Progress = midjRequest.Progress |
|
|
midjourneyTask.PromptEn = midjRequest.PromptEn |
|
|
midjourneyTask.State = midjRequest.State |
|
|
midjourneyTask.SubmitTime = midjRequest.SubmitTime |
|
|
midjourneyTask.StartTime = midjRequest.StartTime |
|
|
midjourneyTask.FinishTime = midjRequest.FinishTime |
|
|
midjourneyTask.ImageUrl = midjRequest.ImageUrl |
|
|
midjourneyTask.VideoUrl = midjRequest.VideoUrl |
|
|
videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls) |
|
|
midjourneyTask.VideoUrls = string(videoUrlsStr) |
|
|
midjourneyTask.Status = midjRequest.Status |
|
|
midjourneyTask.FailReason = midjRequest.FailReason |
|
|
err = midjourneyTask.Update() |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "update_midjourney_task_failed", |
|
|
} |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { |
|
|
midjourneyTask.MjId = originTask.MjId |
|
|
midjourneyTask.Progress = originTask.Progress |
|
|
midjourneyTask.PromptEn = originTask.PromptEn |
|
|
midjourneyTask.State = originTask.State |
|
|
midjourneyTask.SubmitTime = originTask.SubmitTime |
|
|
midjourneyTask.StartTime = originTask.StartTime |
|
|
midjourneyTask.FinishTime = originTask.FinishTime |
|
|
midjourneyTask.ImageUrl = "" |
|
|
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled { |
|
|
midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId |
|
|
if originTask.Status != "SUCCESS" { |
|
|
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) |
|
|
} |
|
|
} else { |
|
|
midjourneyTask.ImageUrl = originTask.ImageUrl |
|
|
} |
|
|
if originTask.VideoUrl != "" { |
|
|
midjourneyTask.VideoUrl = originTask.VideoUrl |
|
|
} |
|
|
midjourneyTask.Status = originTask.Status |
|
|
midjourneyTask.FailReason = originTask.FailReason |
|
|
midjourneyTask.Action = originTask.Action |
|
|
midjourneyTask.Description = originTask.Description |
|
|
midjourneyTask.Prompt = originTask.Prompt |
|
|
if originTask.Buttons != "" { |
|
|
var buttons []dto.ActionButton |
|
|
err := json.Unmarshal([]byte(originTask.Buttons), &buttons) |
|
|
if err == nil { |
|
|
midjourneyTask.Buttons = buttons |
|
|
} |
|
|
} |
|
|
if originTask.VideoUrls != "" { |
|
|
var videoUrls []dto.ImgUrls |
|
|
err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls) |
|
|
if err == nil { |
|
|
midjourneyTask.VideoUrls = videoUrls |
|
|
} |
|
|
} |
|
|
if originTask.Properties != "" { |
|
|
var properties dto.Properties |
|
|
err := json.Unmarshal([]byte(originTask.Properties), &properties) |
|
|
if err == nil { |
|
|
midjourneyTask.Properties = &properties |
|
|
} |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { |
|
|
var swapFaceRequest dto.SwapFaceRequest |
|
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") |
|
|
} |
|
|
|
|
|
info.InitChannelMeta(c) |
|
|
|
|
|
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") |
|
|
} |
|
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace) |
|
|
|
|
|
priceData := helper.ModelPriceHelperPerCall(c, info) |
|
|
|
|
|
userQuota, err := model.GetUserQuota(info.UserId, false) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: err.Error(), |
|
|
} |
|
|
} |
|
|
|
|
|
if userQuota-priceData.Quota < 0 { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "quota_not_enough", |
|
|
} |
|
|
} |
|
|
requestURL := getMjRequestPath(c.Request.URL.String()) |
|
|
baseURL := c.GetString("base_url") |
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |
|
|
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) |
|
|
if err != nil { |
|
|
return &mjResp.Response |
|
|
} |
|
|
defer func() { |
|
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { |
|
|
err := service.PostConsumeQuota(info, priceData.Quota, 0, true) |
|
|
if err != nil { |
|
|
common.SysLog("error consuming token remain quota: " + err.Error()) |
|
|
} |
|
|
|
|
|
tokenName := c.GetString("token_name") |
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) |
|
|
other := service.GenerateMjOtherInfo(info, priceData) |
|
|
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ |
|
|
ChannelId: info.ChannelId, |
|
|
ModelName: modelName, |
|
|
TokenName: tokenName, |
|
|
Quota: priceData.Quota, |
|
|
Content: logContent, |
|
|
TokenId: info.TokenId, |
|
|
Group: info.UsingGroup, |
|
|
Other: other, |
|
|
}) |
|
|
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) |
|
|
model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) |
|
|
} |
|
|
}() |
|
|
midjResponse := &mjResp.Response |
|
|
midjourneyTask := &model.Midjourney{ |
|
|
UserId: info.UserId, |
|
|
Code: midjResponse.Code, |
|
|
Action: constant.MjActionSwapFace, |
|
|
MjId: midjResponse.Result, |
|
|
Prompt: "InsightFace", |
|
|
PromptEn: "", |
|
|
Description: midjResponse.Description, |
|
|
State: "", |
|
|
SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), |
|
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond), |
|
|
FinishTime: 0, |
|
|
ImageUrl: "", |
|
|
Status: "", |
|
|
Progress: "0%", |
|
|
FailReason: "", |
|
|
ChannelId: c.GetInt("channel_id"), |
|
|
Quota: priceData.Quota, |
|
|
} |
|
|
err = midjourneyTask.Insert() |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed") |
|
|
} |
|
|
c.Writer.WriteHeader(mjResp.StatusCode) |
|
|
respBody, err := json.Marshal(midjResponse) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") |
|
|
} |
|
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { |
|
|
taskId := c.Param("id") |
|
|
userId := c.GetInt("id") |
|
|
originTask := model.GetByMJId(userId, taskId) |
|
|
if originTask == nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") |
|
|
} |
|
|
channel, err := model.GetChannelById(originTask.ChannelId, true) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") |
|
|
} |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") |
|
|
} |
|
|
c.Set("channel_id", originTask.ChannelId) |
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
|
|
|
|
|
requestURL := getMjRequestPath(c.Request.URL.String()) |
|
|
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) |
|
|
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) |
|
|
if err != nil { |
|
|
return &midjResponseWithStatus.Response |
|
|
} |
|
|
midjResponse := &midjResponseWithStatus.Response |
|
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) |
|
|
respBody, err := json.Marshal(midjResponse) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") |
|
|
} |
|
|
service.IOCopyBytesGracefully(c, nil, respBody) |
|
|
return nil |
|
|
} |
|
|
|
|
|
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { |
|
|
userId := c.GetInt("id") |
|
|
var err error |
|
|
var respBody []byte |
|
|
switch relayMode { |
|
|
case relayconstant.RelayModeMidjourneyTaskFetch: |
|
|
taskId := c.Param("id") |
|
|
originTask := model.GetByMJId(userId, taskId) |
|
|
if originTask == nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "task_no_found", |
|
|
} |
|
|
} |
|
|
midjourneyTask := coverMidjourneyTaskDto(c, originTask) |
|
|
respBody, err = json.Marshal(midjourneyTask) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "unmarshal_response_body_failed", |
|
|
} |
|
|
} |
|
|
case relayconstant.RelayModeMidjourneyTaskFetchByCondition: |
|
|
var condition = struct { |
|
|
IDs []string `json:"ids"` |
|
|
}{} |
|
|
err = c.BindJSON(&condition) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "do_request_failed", |
|
|
} |
|
|
} |
|
|
var tasks []dto.MidjourneyDto |
|
|
if len(condition.IDs) != 0 { |
|
|
originTasks := model.GetByMJIds(userId, condition.IDs) |
|
|
for _, originTask := range originTasks { |
|
|
midjourneyTask := coverMidjourneyTaskDto(c, originTask) |
|
|
tasks = append(tasks, midjourneyTask) |
|
|
} |
|
|
} |
|
|
if tasks == nil { |
|
|
tasks = make([]dto.MidjourneyDto, 0) |
|
|
} |
|
|
respBody, err = json.Marshal(tasks) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "unmarshal_response_body_failed", |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
c.Writer.Header().Set("Content-Type", "application/json") |
|
|
|
|
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "copy_response_body_failed", |
|
|
} |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { |
|
|
consumeQuota := true |
|
|
var midjRequest dto.MidjourneyRequest |
|
|
err := common.UnmarshalBodyReusable(c, &midjRequest) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") |
|
|
} |
|
|
|
|
|
relayInfo.InitChannelMeta(c) |
|
|
|
|
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { |
|
|
mjErr := service.CoverPlusActionToNormalAction(&midjRequest) |
|
|
if mjErr != nil { |
|
|
return mjErr |
|
|
} |
|
|
relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange |
|
|
} |
|
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { |
|
|
midjRequest.Action = constant.MjActionVideo |
|
|
} |
|
|
|
|
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { |
|
|
if midjRequest.Prompt == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") |
|
|
} |
|
|
midjRequest.Action = constant.MjActionImagine |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { |
|
|
midjRequest.Action = constant.MjActionDescribe |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { |
|
|
midjRequest.Action = constant.MjActionEdits |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { |
|
|
midjRequest.Action = constant.MjActionShorten |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { |
|
|
midjRequest.Action = constant.MjActionBlend |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { |
|
|
midjRequest.Action = constant.MjActionUpload |
|
|
} else if midjRequest.TaskId != "" { |
|
|
mjId := "" |
|
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { |
|
|
if midjRequest.TaskId == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") |
|
|
} else if midjRequest.Action == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") |
|
|
} else if midjRequest.Index == 0 { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required") |
|
|
} |
|
|
|
|
|
mjId = midjRequest.TaskId |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { |
|
|
if midjRequest.Content == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") |
|
|
} |
|
|
params := service.ConvertSimpleChangeParams(midjRequest.Content) |
|
|
if params == nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") |
|
|
} |
|
|
mjId = params.TaskId |
|
|
midjRequest.Action = params.Action |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { |
|
|
|
|
|
|
|
|
|
|
|
mjId = midjRequest.TaskId |
|
|
midjRequest.Action = constant.MjActionModal |
|
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { |
|
|
midjRequest.Action = constant.MjActionVideo |
|
|
if midjRequest.TaskId == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") |
|
|
} else if midjRequest.Action == "" { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") |
|
|
} |
|
|
mjId = midjRequest.TaskId |
|
|
} |
|
|
|
|
|
originTask := model.GetByMJId(relayInfo.UserId, mjId) |
|
|
if originTask == nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") |
|
|
} else { |
|
|
if setting.MjActionCheckSuccessEnabled { |
|
|
if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") |
|
|
} |
|
|
} |
|
|
channel, err := model.GetChannelById(originTask.ChannelId, true) |
|
|
if err != nil { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") |
|
|
} |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") |
|
|
} |
|
|
c.Set("base_url", channel.GetBaseURL()) |
|
|
c.Set("channel_id", originTask.ChannelId) |
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
|
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) |
|
|
} |
|
|
midjRequest.Prompt = originTask.Prompt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom { |
|
|
consumeQuota = false |
|
|
} |
|
|
|
|
|
|
|
|
requestURL := getMjRequestPath(c.Request.URL.String()) |
|
|
|
|
|
baseURL := c.GetString("base_url") |
|
|
|
|
|
|
|
|
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |
|
|
|
|
|
modelName := service.CoverActionToModelName(midjRequest.Action) |
|
|
|
|
|
priceData := helper.ModelPriceHelperPerCall(c, relayInfo) |
|
|
|
|
|
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: err.Error(), |
|
|
} |
|
|
} |
|
|
|
|
|
if consumeQuota && userQuota-priceData.Quota < 0 { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "quota_not_enough", |
|
|
} |
|
|
} |
|
|
|
|
|
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) |
|
|
if err != nil { |
|
|
return &midjResponseWithStatus.Response |
|
|
} |
|
|
midjResponse := &midjResponseWithStatus.Response |
|
|
|
|
|
defer func() { |
|
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { |
|
|
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) |
|
|
if err != nil { |
|
|
common.SysLog("error consuming token remain quota: " + err.Error()) |
|
|
} |
|
|
tokenName := c.GetString("token_name") |
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) |
|
|
other := service.GenerateMjOtherInfo(relayInfo, priceData) |
|
|
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ |
|
|
ChannelId: relayInfo.ChannelId, |
|
|
ModelName: modelName, |
|
|
TokenName: tokenName, |
|
|
Quota: priceData.Quota, |
|
|
Content: logContent, |
|
|
TokenId: relayInfo.TokenId, |
|
|
Group: relayInfo.UsingGroup, |
|
|
Other: other, |
|
|
}) |
|
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) |
|
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) |
|
|
} |
|
|
}() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
midjourneyTask := &model.Midjourney{ |
|
|
UserId: relayInfo.UserId, |
|
|
Code: midjResponse.Code, |
|
|
Action: midjRequest.Action, |
|
|
MjId: midjResponse.Result, |
|
|
Prompt: midjRequest.Prompt, |
|
|
PromptEn: "", |
|
|
Description: midjResponse.Description, |
|
|
State: "", |
|
|
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond), |
|
|
StartTime: 0, |
|
|
FinishTime: 0, |
|
|
ImageUrl: "", |
|
|
Status: "", |
|
|
Progress: "0%", |
|
|
FailReason: "", |
|
|
ChannelId: c.GetInt("channel_id"), |
|
|
Quota: priceData.Quota, |
|
|
} |
|
|
if midjResponse.Code == 3 { |
|
|
|
|
|
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) |
|
|
if err != nil { |
|
|
common.SysLog("get_channel_null: " + err.Error()) |
|
|
} |
|
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { |
|
|
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") |
|
|
} |
|
|
} |
|
|
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { |
|
|
|
|
|
midjourneyTask.FailReason = midjResponse.Description |
|
|
consumeQuota = false |
|
|
} |
|
|
|
|
|
if midjResponse.Code == 21 { |
|
|
|
|
|
properties, ok := midjResponse.Properties.(map[string]interface{}) |
|
|
if ok { |
|
|
imageUrl, ok1 := properties["imageUrl"].(string) |
|
|
status, ok2 := properties["status"].(string) |
|
|
if ok1 && ok2 { |
|
|
midjourneyTask.ImageUrl = imageUrl |
|
|
midjourneyTask.Status = status |
|
|
if status == "SUCCESS" { |
|
|
midjourneyTask.Progress = "100%" |
|
|
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond) |
|
|
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond) |
|
|
midjResponse.Code = 1 |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom { |
|
|
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) |
|
|
responseBody = []byte(newBody) |
|
|
} |
|
|
} |
|
|
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" { |
|
|
midjourneyTask.Progress = "100%" |
|
|
midjourneyTask.Status = "SUCCESS" |
|
|
} |
|
|
err = midjourneyTask.Insert() |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "insert_midjourney_task_failed", |
|
|
} |
|
|
} |
|
|
|
|
|
if midjResponse.Code == 22 { |
|
|
|
|
|
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) |
|
|
responseBody = []byte(newBody) |
|
|
} |
|
|
|
|
|
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) |
|
|
|
|
|
_, err = io.Copy(c.Writer, bodyReader) |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "copy_response_body_failed", |
|
|
} |
|
|
} |
|
|
err = bodyReader.Close() |
|
|
if err != nil { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: 4, |
|
|
Description: "close_response_body_failed", |
|
|
} |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
// taskChangeParams groups a task identifier, an action name and an index.
// NOTE(review): nothing in the visible part of this file references the type;
// presumably it describes a parsed "change" request — confirm before removing.
type taskChangeParams struct {
	ID, Action string // task identifier and action name
	Index      int    // index value accompanying the action
}
|
|
|
|
|
// getMjRequestPath strips a mode prefix from a Midjourney request path.
//
// When path contains "/mj-" (for example "/mj-fast/mj/submit/imagine"),
// everything before the first "/mj/" is dropped and the result starts with
// "/mj/". Paths without "/mj-", or without any "/mj/" segment, are returned
// unchanged.
//
// Only the first "/mj/" is treated as the separator, so a later occurrence
// (for instance inside a query string, since callers pass the full URL
// string) is preserved instead of truncating the forwarded path.
func getMjRequestPath(path string) string {
	if !strings.Contains(path, "/mj-") {
		return path
	}
	parts := strings.SplitN(path, "/mj/", 2)
	if len(parts) < 2 {
		return path
	}
	return "/mj/" + parts[1]
}
|
|
|