Spaces:
Build error
Build error
| package vidu | |
| import ( | |
| "bytes" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "github.com/gin-gonic/gin" | |
| "one-api/constant" | |
| "one-api/dto" | |
| "one-api/model" | |
| "one-api/relay/channel" | |
| relaycommon "one-api/relay/common" | |
| "one-api/service" | |
| "github.com/pkg/errors" | |
| ) | |
| // ============================ | |
| // Request / Response structures | |
| // ============================ | |
// requestPayload is the JSON body assembled by convertToRequestPayload and
// marshalled by BuildRequestBody for the upstream submit call.
//
// Model and Images are always encoded; every other field carries omitempty and
// is left out of the encoded body when it holds its zero value.
type requestPayload struct {
	// Model is the upstream model name.
	Model string `json:"model"`
	// Images is encoded even when nil or empty (there is no omitempty).
	Images []string `json:"images"`
	// Prompt is the text prompt.
	Prompt string `json:"prompt,omitempty"`
	// Duration is the requested duration; unit presumably seconds — confirm
	// against the upstream API.
	Duration int `json:"duration,omitempty"`
	// Seed is omitted from the body when zero.
	Seed int `json:"seed,omitempty"`
	// Resolution is the requested output resolution, e.g. "1080p".
	Resolution string `json:"resolution,omitempty"`
	// MovementAmplitude is set to "auto" by convertToRequestPayload unless
	// overridden through request metadata.
	MovementAmplitude string `json:"movement_amplitude,omitempty"`
	// Bgm is omitted from the body when false.
	Bgm bool `json:"bgm,omitempty"`
	// Payload is an opaque string forwarded as-is.
	Payload string `json:"payload,omitempty"`
	// CallbackUrl is forwarded under the "callback_url" key.
	CallbackUrl string `json:"callback_url,omitempty"`
}
// responsePayload is the shape DoResponse decodes the upstream submit response
// into. DoResponse reads State and TaskId and echoes the whole value back to
// the client as JSON.
type responsePayload struct {
	// TaskId is the upstream task identifier returned to the caller of DoResponse.
	TaskId string `json:"task_id"`
	// State is the task state; DoResponse treats "failed" as an error.
	State string `json:"state"`

	// The remaining fields are decoded from the response and passed through
	// unchanged when the value is re-encoded.
	Model             string   `json:"model"`
	Images            []string `json:"images"`
	Prompt            string   `json:"prompt"`
	Duration          int      `json:"duration"`
	Seed              int      `json:"seed"`
	Resolution        string   `json:"resolution"`
	Bgm               bool     `json:"bgm"`
	MovementAmplitude string   `json:"movement_amplitude"`
	Payload           string   `json:"payload"`
	CreatedAt         string   `json:"created_at"`
}
| type taskResultResponse struct { | |
| State string `json:"state"` | |
| ErrCode string `json:"err_code"` | |
| Credits int `json:"credits"` | |
| Payload string `json:"payload"` | |
| Creations []creation `json:"creations"` | |
| } | |
// creation is one entry of taskResultResponse.Creations.
type creation struct {
	// ID is decoded from the "id" key.
	ID string `json:"id"`
	// URL is the value ParseTaskResult reports as the task result URL.
	URL string `json:"url"`
	// CoverURL is decoded from the "cover_url" key.
	CoverURL string `json:"cover_url"`
}
| // ============================ | |
| // Adaptor implementation | |
| // ============================ | |
// TaskAdaptor is the task adaptor for the "vidu" channel. Both fields are
// populated by Init.
type TaskAdaptor struct {
	// ChannelType is copied from RelayInfo.ChannelType.
	ChannelType int
	// baseURL is copied from RelayInfo.ChannelBaseUrl and prefixes every
	// URL produced by BuildRequestURL.
	baseURL string
}
| func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { | |
| a.ChannelType = info.ChannelType | |
| a.baseURL = info.ChannelBaseUrl | |
| } | |
| func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { | |
| return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) | |
| } | |
| func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { | |
| v, exists := c.Get("task_request") | |
| if !exists { | |
| return nil, fmt.Errorf("request not found in context") | |
| } | |
| req := v.(relaycommon.TaskSubmitReq) | |
| body, err := a.convertToRequestPayload(&req) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if len(body.Images) == 0 { | |
| c.Set("action", constant.TaskActionTextGenerate) | |
| } | |
| data, err := json.Marshal(body) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return bytes.NewReader(data), nil | |
| } | |
| func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { | |
| var path string | |
| switch info.Action { | |
| case constant.TaskActionGenerate: | |
| path = "/img2video" | |
| case constant.TaskActionFirstTailGenerate: | |
| path = "/start-end2video" | |
| case constant.TaskActionReferenceGenerate: | |
| path = "/reference2video" | |
| default: | |
| path = "/text2video" | |
| } | |
| return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil | |
| } | |
| func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Accept", "application/json") | |
| req.Header.Set("Authorization", "Token "+info.ApiKey) | |
| return nil | |
| } | |
| func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { | |
| if action := c.GetString("action"); action != "" { | |
| info.Action = action | |
| } | |
| return channel.DoTaskApiRequest(a, c, info, requestBody) | |
| } | |
| func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { | |
| responseBody, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | |
| return | |
| } | |
| var vResp responsePayload | |
| err = json.Unmarshal(responseBody, &vResp) | |
| if err != nil { | |
| taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) | |
| return | |
| } | |
| if vResp.State == "failed" { | |
| taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest) | |
| return | |
| } | |
| c.JSON(http.StatusOK, vResp) | |
| return vResp.TaskId, responseBody, nil | |
| } | |
| func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { | |
| taskID, ok := body["task_id"].(string) | |
| if !ok { | |
| return nil, fmt.Errorf("invalid task_id") | |
| } | |
| url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID) | |
| req, err := http.NewRequest(http.MethodGet, url, nil) | |
| if err != nil { | |
| return nil, err | |
| } | |
| req.Header.Set("Accept", "application/json") | |
| req.Header.Set("Authorization", "Token "+key) | |
| return service.GetHttpClient().Do(req) | |
| } | |
| func (a *TaskAdaptor) GetModelList() []string { | |
| return []string{"viduq1", "vidu2.0", "vidu1.5"} | |
| } | |
| func (a *TaskAdaptor) GetChannelName() string { | |
| return "vidu" | |
| } | |
| // ============================ | |
| // helpers | |
| // ============================ | |
| func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { | |
| r := requestPayload{ | |
| Model: defaultString(req.Model, "viduq1"), | |
| Images: req.Images, | |
| Prompt: req.Prompt, | |
| Duration: defaultInt(req.Duration, 5), | |
| Resolution: defaultString(req.Size, "1080p"), | |
| MovementAmplitude: "auto", | |
| Bgm: false, | |
| } | |
| metadata := req.Metadata | |
| medaBytes, err := json.Marshal(metadata) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "metadata marshal metadata failed") | |
| } | |
| err = json.Unmarshal(medaBytes, &r) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "unmarshal metadata failed") | |
| } | |
| return &r, nil | |
| } | |
// defaultString returns value unless it is the empty string, in which case it
// returns defaultValue.
func defaultString(value, defaultValue string) string {
	if value != "" {
		return value
	}
	return defaultValue
}
// defaultInt returns value unless it is zero, in which case it returns
// defaultValue. Negative values are returned unchanged.
func defaultInt(value, defaultValue int) int {
	if value != 0 {
		return value
	}
	return defaultValue
}
| func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { | |
| taskInfo := &relaycommon.TaskInfo{} | |
| var taskResp taskResultResponse | |
| err := json.Unmarshal(respBody, &taskResp) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "failed to unmarshal response body") | |
| } | |
| state := taskResp.State | |
| switch state { | |
| case "created", "queueing": | |
| taskInfo.Status = model.TaskStatusSubmitted | |
| case "processing": | |
| taskInfo.Status = model.TaskStatusInProgress | |
| case "success": | |
| taskInfo.Status = model.TaskStatusSuccess | |
| if len(taskResp.Creations) > 0 { | |
| taskInfo.Url = taskResp.Creations[0].URL | |
| } | |
| case "failed": | |
| taskInfo.Status = model.TaskStatusFailure | |
| if taskResp.ErrCode != "" { | |
| taskInfo.Reason = taskResp.ErrCode | |
| } | |
| default: | |
| return nil, fmt.Errorf("unknown task state: %s", state) | |
| } | |
| return taskInfo, nil | |
| } | |