| | package relay |
| |
|
| | import ( |
| | "bytes" |
| | "encoding/json" |
| | "errors" |
| | "fmt" |
| | "github.com/gin-gonic/gin" |
| | "io" |
| | "net/http" |
| | "one-api/common" |
| | "one-api/constant" |
| | "one-api/dto" |
| | "one-api/model" |
| | relaycommon "one-api/relay/common" |
| | relayconstant "one-api/relay/constant" |
| | "one-api/service" |
| | "one-api/setting" |
| | "one-api/setting/operation_setting" |
| | ) |
| |
|
| | |
| | |
| | |
| | func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { |
| | platform := constant.TaskPlatform(c.GetString("platform")) |
| | relayInfo := relaycommon.GenTaskRelayInfo(c) |
| |
|
| | adaptor := GetTaskAdaptor(platform) |
| | if adaptor == nil { |
| | return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) |
| | } |
| | adaptor.Init(relayInfo) |
| | |
| | taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) |
| | if taskErr != nil { |
| | return |
| | } |
| |
|
| | modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) |
| | modelPrice, success := operation_setting.GetModelPrice(modelName, true) |
| | if !success { |
| | defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName] |
| | if !ok { |
| | modelPrice = 0.1 |
| | } else { |
| | modelPrice = defaultPrice |
| | } |
| | } |
| |
|
| | |
| | groupRatio := setting.GetGroupRatio(relayInfo.Group) |
| | ratio := modelPrice * groupRatio |
| | userQuota, err := model.GetUserQuota(relayInfo.UserId, false) |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | quota := int(ratio * common.QuotaPerUnit) |
| | if userQuota-quota < 0 { |
| | taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) |
| | return |
| | } |
| |
|
| | if relayInfo.OriginTaskID != "" { |
| | originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | if !exist { |
| | taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) |
| | return |
| | } |
| | if originTask.ChannelId != relayInfo.ChannelId { |
| | channel, err := model.GetChannelById(originTask.ChannelId, true) |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) |
| | return |
| | } |
| | if channel.Status != common.ChannelStatusEnabled { |
| | return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest) |
| | } |
| | c.Set("base_url", channel.GetBaseURL()) |
| | c.Set("channel_id", originTask.ChannelId) |
| | c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |
| |
|
| | relayInfo.BaseUrl = channel.GetBaseURL() |
| | relayInfo.ChannelId = originTask.ChannelId |
| | } |
| | } |
| |
|
| | |
| | requestBody, err := adaptor.BuildRequestBody(c, relayInfo) |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | |
| | resp, err := adaptor.DoRequest(c, relayInfo, requestBody) |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | |
| | if resp != nil && resp.StatusCode != http.StatusOK { |
| | responseBody, _ := io.ReadAll(resp.Body) |
| | taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode) |
| | return |
| | } |
| |
|
| | defer func() { |
| | |
| | if relayInfo.ConsumeQuota && taskErr == nil { |
| |
|
| | err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) |
| | if err != nil { |
| | common.SysError("error consuming token remain quota: " + err.Error()) |
| | } |
| | if quota != 0 { |
| | tokenName := c.GetString("token_name") |
| | logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) |
| | other := make(map[string]interface{}) |
| | other["model_price"] = modelPrice |
| | other["group_ratio"] = groupRatio |
| | model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, |
| | modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other) |
| | model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) |
| | model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) |
| | } |
| | } |
| | }() |
| |
|
| | taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) |
| | if taskErr != nil { |
| | return |
| | } |
| | relayInfo.ConsumeQuota = true |
| | |
| | task := model.InitTask(constant.TaskPlatformSuno, relayInfo) |
| | task.TaskID = taskID |
| | task.Quota = quota |
| | task.Data = taskData |
| | err = task.Insert() |
| | if err != nil { |
| | taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | return nil |
| | } |
| |
|
| | var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ |
| | relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, |
| | relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, |
| | } |
| |
|
| | func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { |
| | respBuilder, ok := fetchRespBuilders[relayMode] |
| | if !ok { |
| | taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) |
| | } |
| |
|
| | respBody, taskErr := respBuilder(c) |
| | if taskErr != nil { |
| | return taskErr |
| | } |
| |
|
| | c.Writer.Header().Set("Content-Type", "application/json") |
| | _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
| | if err != nil { |
| | taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | return |
| | } |
| |
|
| | func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
| | userId := c.GetInt("id") |
| | var condition = struct { |
| | IDs []any `json:"ids"` |
| | Action string `json:"action"` |
| | }{} |
| | err := c.BindJSON(&condition) |
| | if err != nil { |
| | taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) |
| | return |
| | } |
| | var tasks []any |
| | if len(condition.IDs) > 0 { |
| | taskModels, err := model.GetByTaskIds(userId, condition.IDs) |
| | if err != nil { |
| | taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | for _, task := range taskModels { |
| | tasks = append(tasks, TaskModel2Dto(task)) |
| | } |
| | } else { |
| | tasks = make([]any, 0) |
| | } |
| | respBody, err = json.Marshal(dto.TaskResponse[[]any]{ |
| | Code: "success", |
| | Data: tasks, |
| | }) |
| | return |
| | } |
| |
|
| | func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
| | taskId := c.Param("id") |
| | userId := c.GetInt("id") |
| |
|
| | originTask, exist, err := model.GetByTaskId(userId, taskId) |
| | if err != nil { |
| | taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) |
| | return |
| | } |
| | if !exist { |
| | taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) |
| | return |
| | } |
| |
|
| | respBody, err = json.Marshal(dto.TaskResponse[any]{ |
| | Code: "success", |
| | Data: TaskModel2Dto(originTask), |
| | }) |
| | return |
| | } |
| |
|
| | func TaskModel2Dto(task *model.Task) *dto.TaskDto { |
| | return &dto.TaskDto{ |
| | TaskID: task.TaskID, |
| | Action: task.Action, |
| | Status: string(task.Status), |
| | FailReason: task.FailReason, |
| | SubmitTime: task.SubmitTime, |
| | StartTime: task.StartTime, |
| | FinishTime: task.FinishTime, |
| | Progress: task.Progress, |
| | Data: task.Data, |
| | } |
| | } |
| |
|