|
|
package relay |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"encoding/json" |
|
|
"errors" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/model" |
|
|
"github.com/QuantumNous/new-api/relay/channel" |
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common" |
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant" |
|
|
"github.com/QuantumNous/new-api/service" |
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { |
|
|
info.InitChannelMeta(c) |
|
|
|
|
|
if info.TaskRelayInfo == nil { |
|
|
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} |
|
|
} |
|
|
platform := constant.TaskPlatform(c.GetString("platform")) |
|
|
if platform == "" { |
|
|
platform = GetTaskPlatform(c) |
|
|
} |
|
|
|
|
|
info.InitChannelMeta(c) |
|
|
adaptor := GetTaskAdaptor(platform) |
|
|
if adaptor == nil { |
|
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) |
|
|
} |
|
|
adaptor.Init(info) |
|
|
|
|
|
taskErr = adaptor.ValidateRequestAndSetAction(c, info) |
|
|
if taskErr != nil { |
|
|
return |
|
|
} |
|
|
|
|
|
modelName := info.OriginModelName |
|
|
if modelName == "" { |
|
|
modelName = service.CoverTaskActionToModelName(platform, info.Action) |
|
|
} |
|
|
modelPrice, success := ratio_setting.GetModelPrice(modelName, true) |
|
|
if !success { |
|
|
defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] |
|
|
if !ok { |
|
|
modelPrice = 0.1 |
|
|
} else { |
|
|
modelPrice = defaultPrice |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) |
|
|
var ratio float64 |
|
|
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) |
|
|
if hasUserGroupRatio { |
|
|
ratio = modelPrice * userGroupRatio |
|
|
} else { |
|
|
ratio = modelPrice * groupRatio |
|
|
} |
|
|
|
|
|
if !common.StringsContains(constant.TaskPricePatches, modelName) { |
|
|
if len(info.PriceData.OtherRatios) > 0 { |
|
|
for _, ra := range info.PriceData.OtherRatios { |
|
|
if 1.0 != ra { |
|
|
ratio *= ra |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) |
|
|
userQuota, err := model.GetUserQuota(info.UserId, false) |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
quota := int(ratio * common.QuotaPerUnit) |
|
|
if userQuota-quota < 0 { |
|
|
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) |
|
|
return |
|
|
} |
|
|
|
|
|
if info.OriginTaskID != "" { |
|
|
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
if !exist { |
|
|
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
if originTask.ChannelId != info.ChannelId { |
|
|
channel, err := model.GetChannelById(originTask.ChannelId, true) |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest) |
|
|
} |
|
|
c.Set("base_url", channel.GetBaseURL()) |
|
|
c.Set("channel_id", originTask.ChannelId) |
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
|
|
|
|
|
info.ChannelBaseUrl = channel.GetBaseURL() |
|
|
info.ChannelId = originTask.ChannelId |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
requestBody, err := adaptor.BuildRequestBody(c, info) |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
|
|
|
resp, err := adaptor.DoRequest(c, info, requestBody) |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
|
|
|
if resp != nil && resp.StatusCode != http.StatusOK { |
|
|
responseBody, _ := io.ReadAll(resp.Body) |
|
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode) |
|
|
return |
|
|
} |
|
|
|
|
|
defer func() { |
|
|
|
|
|
if info.ConsumeQuota && taskErr == nil { |
|
|
|
|
|
err := service.PostConsumeQuota(info, quota, 0, true) |
|
|
if err != nil { |
|
|
common.SysLog("error consuming token remain quota: " + err.Error()) |
|
|
} |
|
|
if quota != 0 { |
|
|
tokenName := c.GetString("token_name") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logContent := fmt.Sprintf("操作 %s", info.Action) |
|
|
|
|
|
if common.StringsContains(constant.TaskPricePatches, modelName) { |
|
|
logContent = fmt.Sprintf("%s,按次计费", logContent) |
|
|
} else { |
|
|
if len(info.PriceData.OtherRatios) > 0 { |
|
|
var contents []string |
|
|
for key, ra := range info.PriceData.OtherRatios { |
|
|
if 1.0 != ra { |
|
|
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) |
|
|
} |
|
|
} |
|
|
if len(contents) > 0 { |
|
|
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) |
|
|
} |
|
|
} |
|
|
} |
|
|
other := make(map[string]interface{}) |
|
|
if c != nil && c.Request != nil && c.Request.URL != nil { |
|
|
other["request_path"] = c.Request.URL.Path |
|
|
} |
|
|
other["model_price"] = modelPrice |
|
|
other["group_ratio"] = groupRatio |
|
|
if hasUserGroupRatio { |
|
|
other["user_group_ratio"] = userGroupRatio |
|
|
} |
|
|
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ |
|
|
ChannelId: info.ChannelId, |
|
|
ModelName: modelName, |
|
|
TokenName: tokenName, |
|
|
Quota: quota, |
|
|
Content: logContent, |
|
|
TokenId: info.TokenId, |
|
|
Group: info.UsingGroup, |
|
|
Other: other, |
|
|
}) |
|
|
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) |
|
|
model.UpdateChannelUsedQuota(info.ChannelId, quota) |
|
|
} |
|
|
} |
|
|
}() |
|
|
|
|
|
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) |
|
|
if taskErr != nil { |
|
|
return |
|
|
} |
|
|
info.ConsumeQuota = true |
|
|
|
|
|
task := model.InitTask(platform, info) |
|
|
task.TaskID = taskID |
|
|
task.Quota = quota |
|
|
task.Data = taskData |
|
|
task.Action = info.Action |
|
|
err = task.Insert() |
|
|
if err != nil { |
|
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ |
|
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, |
|
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, |
|
|
relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder, |
|
|
} |
|
|
|
|
|
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { |
|
|
respBuilder, ok := fetchRespBuilders[relayMode] |
|
|
if !ok { |
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) |
|
|
} |
|
|
|
|
|
respBody, taskErr := respBuilder(c) |
|
|
if taskErr != nil { |
|
|
return taskErr |
|
|
} |
|
|
if len(respBody) == 0 { |
|
|
respBody = []byte("{\"code\":\"success\",\"data\":null}") |
|
|
} |
|
|
|
|
|
c.Writer.Header().Set("Content-Type", "application/json") |
|
|
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
|
|
userId := c.GetInt("id") |
|
|
var condition = struct { |
|
|
IDs []any `json:"ids"` |
|
|
Action string `json:"action"` |
|
|
}{} |
|
|
err := c.BindJSON(&condition) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
var tasks []any |
|
|
if len(condition.IDs) > 0 { |
|
|
taskModels, err := model.GetByTaskIds(userId, condition.IDs) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
for _, task := range taskModels { |
|
|
tasks = append(tasks, TaskModel2Dto(task)) |
|
|
} |
|
|
} else { |
|
|
tasks = make([]any, 0) |
|
|
} |
|
|
respBody, err = json.Marshal(dto.TaskResponse[[]any]{ |
|
|
Code: "success", |
|
|
Data: tasks, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
|
|
taskId := c.Param("id") |
|
|
userId := c.GetInt("id") |
|
|
|
|
|
originTask, exist, err := model.GetByTaskId(userId, taskId) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
if !exist { |
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
|
|
|
respBody, err = json.Marshal(dto.TaskResponse[any]{ |
|
|
Code: "success", |
|
|
Data: TaskModel2Dto(originTask), |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
|
|
taskId := c.Param("task_id") |
|
|
if taskId == "" { |
|
|
taskId = c.GetString("task_id") |
|
|
} |
|
|
userId := c.GetInt("id") |
|
|
|
|
|
originTask, exist, err := model.GetByTaskId(userId, taskId) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
if !exist { |
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
|
|
|
func() { |
|
|
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) |
|
|
if err2 != nil { |
|
|
return |
|
|
} |
|
|
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { |
|
|
return |
|
|
} |
|
|
baseURL := constant.ChannelBaseURLs[channelModel.Type] |
|
|
if channelModel.GetBaseURL() != "" { |
|
|
baseURL = channelModel.GetBaseURL() |
|
|
} |
|
|
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) |
|
|
if adaptor == nil { |
|
|
return |
|
|
} |
|
|
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ |
|
|
"task_id": originTask.TaskID, |
|
|
"action": originTask.Action, |
|
|
}) |
|
|
if err2 != nil || resp == nil { |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
body, err2 := io.ReadAll(resp.Body) |
|
|
if err2 != nil { |
|
|
return |
|
|
} |
|
|
ti, err2 := adaptor.ParseTaskResult(body) |
|
|
if err2 == nil && ti != nil { |
|
|
if ti.Status != "" { |
|
|
originTask.Status = model.TaskStatus(ti.Status) |
|
|
} |
|
|
if ti.Progress != "" { |
|
|
originTask.Progress = ti.Progress |
|
|
} |
|
|
if ti.Url != "" { |
|
|
if strings.HasPrefix(ti.Url, "data:") { |
|
|
} else { |
|
|
originTask.FailReason = ti.Url |
|
|
} |
|
|
} |
|
|
_ = originTask.Update() |
|
|
var raw map[string]any |
|
|
_ = json.Unmarshal(body, &raw) |
|
|
format := "mp4" |
|
|
if respObj, ok := raw["response"].(map[string]any); ok { |
|
|
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { |
|
|
if v0, ok := vids[0].(map[string]any); ok { |
|
|
if mt, ok := v0["mimeType"].(string); ok && mt != "" { |
|
|
if strings.Contains(mt, "mp4") { |
|
|
format = "mp4" |
|
|
} else { |
|
|
format = mt |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
status := "processing" |
|
|
switch originTask.Status { |
|
|
case model.TaskStatusSuccess: |
|
|
status = "succeeded" |
|
|
case model.TaskStatusFailure: |
|
|
status = "failed" |
|
|
case model.TaskStatusQueued, model.TaskStatusSubmitted: |
|
|
status = "queued" |
|
|
} |
|
|
if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { |
|
|
out := map[string]any{ |
|
|
"error": nil, |
|
|
"format": format, |
|
|
"metadata": nil, |
|
|
"status": status, |
|
|
"task_id": originTask.TaskID, |
|
|
"url": originTask.FailReason, |
|
|
} |
|
|
respBody, _ = json.Marshal(dto.TaskResponse[any]{ |
|
|
Code: "success", |
|
|
Data: out, |
|
|
}) |
|
|
} |
|
|
} |
|
|
}() |
|
|
|
|
|
if len(respBody) != 0 { |
|
|
return |
|
|
} |
|
|
|
|
|
if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { |
|
|
adaptor := GetTaskAdaptor(originTask.Platform) |
|
|
if adaptor == nil { |
|
|
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok { |
|
|
openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
respBody = openAIVideoData |
|
|
return |
|
|
} |
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) |
|
|
return |
|
|
} |
|
|
respBody, err = json.Marshal(dto.TaskResponse[any]{ |
|
|
Code: "success", |
|
|
Data: TaskModel2Dto(originTask), |
|
|
}) |
|
|
if err != nil { |
|
|
taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError) |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto { |
|
|
return &dto.TaskDto{ |
|
|
TaskID: task.TaskID, |
|
|
Action: task.Action, |
|
|
Status: string(task.Status), |
|
|
FailReason: task.FailReason, |
|
|
SubmitTime: task.SubmitTime, |
|
|
StartTime: task.StartTime, |
|
|
FinishTime: task.FinishTime, |
|
|
Progress: task.Progress, |
|
|
Data: task.Data, |
|
|
} |
|
|
} |
|
|
|