new-api / relay /relay_task.go
liuzhao521
Deploy New API v0.9.25+ (commit b47cf4ef) to HuggingFace Spaces
4674012
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
// RelayTaskSubmit submits an asynchronous task to the upstream channel.
// Tasks are distinguished by platform and action.
//
// Steps, in order:
//  1. resolve the task adaptor for the platform and validate the request;
//  2. compute the quota from the model price, the group ratio (or the
//     user-group ratio when one is configured) and any extra price ratios;
//  3. reject the request when the user's remaining quota is insufficient;
//  4. when info.OriginTaskID is set, re-route to the channel that owns the
//     original task;
//  5. build and send the upstream request;
//  6. store the task and, through the deferred settlement, consume the quota
//     and record the consume log.
//
// It returns nil on success and a *dto.TaskError describing the failure
// otherwise.
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
	info.InitChannelMeta(c)
	// Ensure TaskRelayInfo is initialized to avoid a nil dereference when
	// accessing embedded fields.
	if info.TaskRelayInfo == nil {
		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
	}
	platform := constant.TaskPlatform(c.GetString("platform"))
	if platform == "" {
		platform = GetTaskPlatform(c)
	}
	// NOTE(review): InitChannelMeta was already called at the top of this
	// function; the second call looks redundant but is kept because its side
	// effects are not visible from here — confirm before removing.
	info.InitChannelMeta(c)

	adaptor := GetTaskAdaptor(platform)
	if adaptor == nil {
		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
	}
	adaptor.Init(info)

	// Get and validate the task request.
	if taskErr = adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
		return taskErr
	}

	modelName := info.OriginModelName
	if modelName == "" {
		modelName = service.CoverTaskActionToModelName(platform, info.Action)
	}
	modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
	if !success {
		// Fall back to the default price table, then to a fixed price of 0.1.
		if defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]; ok {
			modelPrice = defaultPrice
		} else {
			modelPrice = 0.1
		}
	}

	// Quota calculation: a user-group specific ratio replaces the group ratio.
	groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
	ratio := modelPrice * groupRatio
	if hasUserGroupRatio {
		ratio = modelPrice * userGroupRatio
	}

	// FIXME: temporary patch — models listed in constant.TaskPricePatches are
	// billed per call only, so the extra ratios are not applied to them.
	perCallOnly := common.StringsContains(constant.TaskPricePatches, modelName)
	if !perCallOnly {
		for _, ra := range info.PriceData.OtherRatios {
			if ra != 1.0 {
				ratio *= ra
			}
		}
	}
	common.SysLog(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))

	userQuota, err := model.GetUserQuota(info.UserId, false)
	if err != nil {
		return service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	quota := int(ratio * common.QuotaPerUnit)
	if userQuota-quota < 0 {
		return service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
	}

	if info.OriginTaskID != "" {
		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
		if err != nil {
			return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
		}
		if !exist {
			return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
		}
		if originTask.ChannelId != info.ChannelId {
			// The follow-up request must go to the channel that created the
			// original task, so switch base URL, channel id and credentials.
			originChannel, err := model.GetChannelById(originTask.ChannelId, true)
			if err != nil {
				return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
			}
			if originChannel.Status != common.ChannelStatusEnabled {
				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
			}
			c.Set("base_url", originChannel.GetBaseURL())
			c.Set("channel_id", originTask.ChannelId)
			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", originChannel.Key))
			info.ChannelBaseUrl = originChannel.GetBaseURL()
			info.ChannelId = originTask.ChannelId
		}
	}

	// Build the upstream request body.
	requestBody, err := adaptor.BuildRequestBody(c, info)
	if err != nil {
		return service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
	}

	// Send the request upstream.
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	// Any non-200 upstream status is surfaced to the caller with the upstream
	// body as the message.
	if resp != nil && resp.StatusCode != http.StatusOK {
		// Best effort: a read error still leaves whatever was read as the message.
		responseBody, _ := io.ReadAll(resp.Body)
		// Close the body so the underlying connection is released.
		_ = resp.Body.Close()
		// errors.New (not fmt.Errorf) so a '%' in the upstream body is not
		// interpreted as a format verb.
		return service.TaskErrorWrapper(errors.New(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
	}

	// Settlement runs after the return value is set: quota is consumed only
	// when info.ConsumeQuota is set and no error is being returned.
	defer func() {
		if !info.ConsumeQuota || taskErr != nil {
			return
		}
		if err := service.PostConsumeQuota(info, quota, 0, true); err != nil {
			common.SysLog("error consuming token remain quota: " + err.Error())
		}
		if quota == 0 {
			return
		}

		tokenName := c.GetString("token_name")
		logContent := fmt.Sprintf("操作 %s", info.Action)
		// FIXME: temporary patch — per-call billing for patched models.
		if perCallOnly {
			logContent = fmt.Sprintf("%s,按次计费", logContent)
		} else {
			var contents []string
			for key, ra := range info.PriceData.OtherRatios {
				if ra != 1.0 {
					contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
				}
			}
			if len(contents) > 0 {
				logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
			}
		}

		other := make(map[string]any)
		if c != nil && c.Request != nil && c.Request.URL != nil {
			other["request_path"] = c.Request.URL.Path
		}
		other["model_price"] = modelPrice
		other["group_ratio"] = groupRatio
		if hasUserGroupRatio {
			other["user_group_ratio"] = userGroupRatio
		}
		model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
			ChannelId: info.ChannelId,
			ModelName: modelName,
			TokenName: tokenName,
			Quota:     quota,
			Content:   logContent,
			TokenId:   info.TokenId,
			Group:     info.UsingGroup,
			Other:     other,
		})
		model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
		model.UpdateChannelUsedQuota(info.ChannelId, quota)
	}()

	// taskErr here is the named result, not a shadow: a short variable
	// declaration in the function's top-level block reuses result parameters.
	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
	if taskErr != nil {
		return taskErr
	}
	info.ConsumeQuota = true

	// Persist the task.
	task := model.InitTask(platform, info)
	task.TaskID = taskID
	task.Quota = quota
	task.Data = taskData
	task.Action = info.Action
	if err := task.Insert(); err != nil {
		return service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
	}
	return nil
}
// fetchRespBuilders maps a relay mode to the function that builds the response
// body for that kind of task-fetch request. RelayTaskFetch looks the builder
// up by relay mode.
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
}
// RelayTaskFetch serves a task-fetch request: it picks the response builder
// registered for relayMode in fetchRespBuilders, runs it, and writes the
// resulting JSON body to the client.
//
// An unknown relayMode yields an "invalid_relay_mode" error. When the builder
// returns an empty body, a success envelope with null data is written instead.
// It returns nil on success.
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
	respBuilder, ok := fetchRespBuilders[relayMode]
	if !ok {
		// Return here: calling the missing (nil) builder would panic.
		return service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
	}

	respBody, taskErr := respBuilder(c)
	if taskErr != nil {
		return taskErr
	}
	if len(respBody) == 0 {
		respBody = []byte(`{"code":"success","data":null}`)
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	if _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)); err != nil {
		return service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
// sunoFetchRespBodyBuilder builds the response body for a batch task fetch.
//
// The request body is JSON with an "ids" list (and an "action" field that is
// bound but not used here). The tasks of the current user (context key "id")
// matching those ids are converted with TaskModel2Dto and wrapped in a
// success dto.TaskResponse.
//
// It returns the encoded body, or a *dto.TaskError when the request cannot be
// bound, the lookup fails, or the response cannot be encoded.
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	userId := c.GetInt("id")
	var condition = struct {
		IDs    []any  `json:"ids"`
		Action string `json:"action"`
	}{}
	if err := c.BindJSON(&condition); err != nil {
		return nil, service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
	}

	// With no ids, tasks is an empty (non-nil) slice and encodes as [].
	// With ids but no matching task it stays nil and encodes as null.
	var tasks []any
	if len(condition.IDs) > 0 {
		taskModels, err := model.GetByTaskIds(userId, condition.IDs)
		if err != nil {
			return nil, service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
		}
		for _, task := range taskModels {
			tasks = append(tasks, TaskModel2Dto(task))
		}
	} else {
		tasks = make([]any, 0)
	}

	respBody, err := json.Marshal(dto.TaskResponse[[]any]{
		Code: "success",
		Data: tasks,
	})
	if err != nil {
		// Previously ignored, which made the caller answer "success" with null data.
		return nil, service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return respBody, nil
}
// sunoFetchByIDRespBodyBuilder builds the response body for fetching one task.
//
// The task id is the "id" path parameter and the user id is the "id" context
// value. The task is looked up for that user, converted with TaskModel2Dto and
// wrapped in a success dto.TaskResponse.
//
// It returns the encoded body, or a *dto.TaskError when the lookup fails, the
// task does not exist, or the response cannot be encoded.
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	taskId := c.Param("id")
	userId := c.GetInt("id")

	originTask, exist, err := model.GetByTaskId(userId, taskId)
	if err != nil {
		return nil, service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
	}
	if !exist {
		return nil, service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
	}

	respBody, err = json.Marshal(dto.TaskResponse[any]{
		Code: "success",
		Data: TaskModel2Dto(originTask),
	})
	if err != nil {
		// Previously ignored, which made the caller answer "success" with null data.
		return nil, service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return respBody, nil
}
// videoFetchByIDRespBodyBuilder builds the response body for fetching one
// video task.
//
// The task id is the "task_id" path parameter, falling back to the "task_id"
// context value; the user id is the "id" context value.
//
// For tasks on a Vertex AI or Gemini channel the upstream is polled first
// (best effort: any failure in that step is skipped) and the stored task is
// refreshed with the reported status, progress and result URL. For such tasks,
// requests outside /v1/videos/ are answered with a compact status document.
//
// Otherwise, requests under /v1/videos/ are converted through the adaptor's
// channel.OpenAIVideoConverter, and all other requests receive the task
// wrapped in a success dto.TaskResponse.
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	taskId := c.Param("task_id")
	if taskId == "" {
		taskId = c.GetString("task_id")
	}
	userId := c.GetInt("id")

	originTask, exist, err := model.GetByTaskId(userId, taskId)
	if err != nil {
		return nil, service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
	}
	if !exist {
		return nil, service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
	}

	isOpenAIVideoPath := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")

	// Best-effort live refresh for Vertex AI / Gemini tasks. Every failure in
	// here simply returns, leaving the stored task to be served below.
	func() {
		channelModel, chErr := model.GetChannelById(originTask.ChannelId, true)
		if chErr != nil {
			return
		}
		if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
			return
		}
		baseURL := constant.ChannelBaseURLs[channelModel.Type]
		if customURL := channelModel.GetBaseURL(); customURL != "" {
			baseURL = customURL
		}
		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
		if adaptor == nil {
			return
		}
		resp, fetchErr := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
			"task_id": originTask.TaskID,
			"action":  originTask.Action,
		})
		if fetchErr != nil || resp == nil {
			return
		}
		defer resp.Body.Close()

		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return
		}
		ti, parseErr := adaptor.ParseTaskResult(body)
		if parseErr != nil || ti == nil {
			return
		}

		if ti.Status != "" {
			originTask.Status = model.TaskStatus(ti.Status)
		}
		if ti.Progress != "" {
			originTask.Progress = ti.Progress
		}
		// The result URL is kept in FailReason (it is read back as "url"
		// below); inline "data:" URLs are not stored.
		if ti.Url != "" && !strings.HasPrefix(ti.Url, "data:") {
			originTask.FailReason = ti.Url
		}
		if updErr := originTask.Update(); updErr != nil {
			// Still answer with the refreshed in-memory state, but do not
			// hide the persistence failure.
			common.SysLog(fmt.Sprintf("update video task %v failed: %s", originTask.TaskID, updErr.Error()))
		}

		if isOpenAIVideoPath {
			// /v1/videos/ requests are rendered by the converter below.
			return
		}

		status := "processing"
		switch originTask.Status {
		case model.TaskStatusSuccess:
			status = "succeeded"
		case model.TaskStatusFailure:
			status = "failed"
		case model.TaskStatusQueued, model.TaskStatusSubmitted:
			status = "queued"
		}
		encoded, marshalErr := json.Marshal(dto.TaskResponse[any]{
			Code: "success",
			Data: map[string]any{
				"error":    nil,
				"format":   videoFormatFromUpstream(body),
				"metadata": nil,
				"status":   status,
				"task_id":  originTask.TaskID,
				"url":      originTask.FailReason,
			},
		})
		if marshalErr != nil {
			// Leave respBody empty so the generic response below is used.
			return
		}
		respBody = encoded
	}()
	if len(respBody) != 0 {
		return respBody, nil
	}

	if isOpenAIVideoPath {
		adaptor := GetTaskAdaptor(originTask.Platform)
		if adaptor == nil {
			return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
		}
		converter, ok := adaptor.(channel.OpenAIVideoConverter)
		if !ok {
			return nil, service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
		}
		openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
		if err != nil {
			return nil, service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
		}
		return openAIVideoData, nil
	}

	respBody, err = json.Marshal(dto.TaskResponse[any]{
		Code: "success",
		Data: TaskModel2Dto(originTask),
	})
	if err != nil {
		return respBody, service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return respBody, nil
}

// videoFormatFromUpstream extracts the video format from an upstream task
// result by reading response.videos[0].mimeType. It returns "mp4" when the
// body is not a JSON object, the field is missing or empty, or the MIME type
// contains "mp4"; otherwise it returns the MIME type unchanged.
func videoFormatFromUpstream(body []byte) string {
	const defaultFormat = "mp4"

	var raw map[string]any
	if err := json.Unmarshal(body, &raw); err != nil {
		return defaultFormat
	}
	respObj, ok := raw["response"].(map[string]any)
	if !ok {
		return defaultFormat
	}
	vids, ok := respObj["videos"].([]any)
	if !ok || len(vids) == 0 {
		return defaultFormat
	}
	first, ok := vids[0].(map[string]any)
	if !ok {
		return defaultFormat
	}
	mimeType, ok := first["mimeType"].(string)
	if !ok || mimeType == "" || strings.Contains(mimeType, "mp4") {
		return defaultFormat
	}
	return mimeType
}
// TaskModel2Dto copies the client-facing fields of a stored task into a newly
// allocated dto.TaskDto. The status is converted to its string form; all other
// fields are copied as they are. task must not be nil.
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
	out := new(dto.TaskDto)

	// Identity and state.
	out.TaskID = task.TaskID
	out.Action = task.Action
	out.Status = string(task.Status)
	out.FailReason = task.FailReason
	out.Progress = task.Progress

	// Timestamps.
	out.SubmitTime = task.SubmitTime
	out.StartTime = task.StartTime
	out.FinishTime = task.FinishTime

	// Payload.
	out.Data = task.Data

	return out
}