Spaces:
Build error
Build error
| package jimeng | |
| import ( | |
| "bytes" | |
| "crypto/hmac" | |
| "crypto/sha256" | |
| "encoding/hex" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "net/url" | |
| "one-api/model" | |
| "sort" | |
| "strings" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/pkg/errors" | |
| "one-api/constant" | |
| "one-api/dto" | |
| "one-api/relay/channel" | |
| relaycommon "one-api/relay/common" | |
| "one-api/service" | |
| ) | |
| // ============================ | |
| // Request / Response structures | |
| // ============================ | |
| type requestPayload struct { | |
| ReqKey string `json:"req_key"` | |
| BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` | |
| ImageUrls []string `json:"image_urls,omitempty"` | |
| Prompt string `json:"prompt,omitempty"` | |
| Seed int64 `json:"seed"` | |
| AspectRatio string `json:"aspect_ratio"` | |
| Frames int `json:"frames,omitempty"` | |
| } | |
| type responsePayload struct { | |
| Code int `json:"code"` | |
| Message string `json:"message"` | |
| RequestId string `json:"request_id"` | |
| Data struct { | |
| TaskID string `json:"task_id"` | |
| } `json:"data"` | |
| } | |
| type responseTask struct { | |
| Code int `json:"code"` | |
| Data struct { | |
| BinaryDataBase64 []interface{} `json:"binary_data_base64"` | |
| ImageUrls interface{} `json:"image_urls"` | |
| RespData string `json:"resp_data"` | |
| Status string `json:"status"` | |
| VideoUrl string `json:"video_url"` | |
| } `json:"data"` | |
| Message string `json:"message"` | |
| RequestId string `json:"request_id"` | |
| Status int `json:"status"` | |
| TimeElapsed string `json:"time_elapsed"` | |
| } | |
| // ============================ | |
| // Adaptor implementation | |
| // ============================ | |
| type TaskAdaptor struct { | |
| ChannelType int | |
| accessKey string | |
| secretKey string | |
| baseURL string | |
| } | |
| func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { | |
| a.ChannelType = info.ChannelType | |
| a.baseURL = info.ChannelBaseUrl | |
| // apiKey format: "access_key|secret_key" | |
| keyParts := strings.Split(info.ApiKey, "|") | |
| if len(keyParts) == 2 { | |
| a.accessKey = strings.TrimSpace(keyParts[0]) | |
| a.secretKey = strings.TrimSpace(keyParts[1]) | |
| } | |
| } | |
| // ValidateRequestAndSetAction parses body, validates fields and sets default action. | |
| func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { | |
| // Accept only POST /v1/video/generations as "generate" action. | |
| return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) | |
| } | |
| // BuildRequestURL constructs the upstream URL. | |
| func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { | |
| if isNewAPIRelay(info.ApiKey) { | |
| return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil | |
| } | |
| return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil | |
| } | |
| // BuildRequestHeader sets required headers. | |
| func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Accept", "application/json") | |
| if isNewAPIRelay(info.ApiKey) { | |
| req.Header.Set("Authorization", "Bearer "+info.ApiKey) | |
| } else { | |
| return a.signRequest(req, a.accessKey, a.secretKey) | |
| } | |
| return nil | |
| } | |
| // BuildRequestBody converts request into Jimeng specific format. | |
| func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { | |
| v, exists := c.Get("task_request") | |
| if !exists { | |
| return nil, fmt.Errorf("request not found in context") | |
| } | |
| req := v.(relaycommon.TaskSubmitReq) | |
| body, err := a.convertToRequestPayload(&req) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "convert request payload failed") | |
| } | |
| data, err := json.Marshal(body) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return bytes.NewReader(data), nil | |
| } | |
| // DoRequest delegates to common helper. | |
| func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { | |
| return channel.DoTaskApiRequest(a, c, info, requestBody) | |
| } | |
| // DoResponse handles upstream response, returns taskID etc. | |
| func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { | |
| responseBody, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | |
| return | |
| } | |
| _ = resp.Body.Close() | |
| // Parse Jimeng response | |
| var jResp responsePayload | |
| if err := json.Unmarshal(responseBody, &jResp); err != nil { | |
| taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) | |
| return | |
| } | |
| if jResp.Code != 10000 { | |
| taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID}) | |
| return jResp.Data.TaskID, responseBody, nil | |
| } | |
| // FetchTask fetch task status | |
| func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { | |
| taskID, ok := body["task_id"].(string) | |
| if !ok { | |
| return nil, fmt.Errorf("invalid task_id") | |
| } | |
| uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) | |
| if isNewAPIRelay(key) { | |
| uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL) | |
| } | |
| payload := map[string]string{ | |
| "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 | |
| "task_id": taskID, | |
| } | |
| payloadBytes, err := json.Marshal(payload) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "marshal fetch task payload failed") | |
| } | |
| req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes)) | |
| if err != nil { | |
| return nil, err | |
| } | |
| req.Header.Set("Accept", "application/json") | |
| req.Header.Set("Content-Type", "application/json") | |
| if isNewAPIRelay(key) { | |
| req.Header.Set("Authorization", "Bearer "+key) | |
| } else { | |
| keyParts := strings.Split(key, "|") | |
| if len(keyParts) != 2 { | |
| return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") | |
| } | |
| accessKey := strings.TrimSpace(keyParts[0]) | |
| secretKey := strings.TrimSpace(keyParts[1]) | |
| if err := a.signRequest(req, accessKey, secretKey); err != nil { | |
| return nil, errors.Wrap(err, "sign request failed") | |
| } | |
| } | |
| return service.GetHttpClient().Do(req) | |
| } | |
| func (a *TaskAdaptor) GetModelList() []string { | |
| return []string{"jimeng_vgfm_t2v_l20"} | |
| } | |
| func (a *TaskAdaptor) GetChannelName() string { | |
| return "jimeng" | |
| } | |
| func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error { | |
| var bodyBytes []byte | |
| var err error | |
| if req.Body != nil { | |
| bodyBytes, err = io.ReadAll(req.Body) | |
| if err != nil { | |
| return errors.Wrap(err, "read request body failed") | |
| } | |
| _ = req.Body.Close() | |
| req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind | |
| } else { | |
| bodyBytes = []byte{} | |
| } | |
| payloadHash := sha256.Sum256(bodyBytes) | |
| hexPayloadHash := hex.EncodeToString(payloadHash[:]) | |
| t := time.Now().UTC() | |
| xDate := t.Format("20060102T150405Z") | |
| shortDate := t.Format("20060102") | |
| req.Header.Set("Host", req.URL.Host) | |
| req.Header.Set("X-Date", xDate) | |
| req.Header.Set("X-Content-Sha256", hexPayloadHash) | |
| // Sort and encode query parameters to create canonical query string | |
| queryParams := req.URL.Query() | |
| sortedKeys := make([]string, 0, len(queryParams)) | |
| for k := range queryParams { | |
| sortedKeys = append(sortedKeys, k) | |
| } | |
| sort.Strings(sortedKeys) | |
| var queryParts []string | |
| for _, k := range sortedKeys { | |
| values := queryParams[k] | |
| sort.Strings(values) | |
| for _, v := range values { | |
| queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v))) | |
| } | |
| } | |
| canonicalQueryString := strings.Join(queryParts, "&") | |
| headersToSign := map[string]string{ | |
| "host": req.URL.Host, | |
| "x-date": xDate, | |
| "x-content-sha256": hexPayloadHash, | |
| } | |
| if req.Header.Get("Content-Type") != "" { | |
| headersToSign["content-type"] = req.Header.Get("Content-Type") | |
| } | |
| var signedHeaderKeys []string | |
| for k := range headersToSign { | |
| signedHeaderKeys = append(signedHeaderKeys, k) | |
| } | |
| sort.Strings(signedHeaderKeys) | |
| var canonicalHeaders strings.Builder | |
| for _, k := range signedHeaderKeys { | |
| canonicalHeaders.WriteString(k) | |
| canonicalHeaders.WriteString(":") | |
| canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k])) | |
| canonicalHeaders.WriteString("\n") | |
| } | |
| signedHeaders := strings.Join(signedHeaderKeys, ";") | |
| canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", | |
| req.Method, | |
| req.URL.Path, | |
| canonicalQueryString, | |
| canonicalHeaders.String(), | |
| signedHeaders, | |
| hexPayloadHash, | |
| ) | |
| hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest)) | |
| hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:]) | |
| region := "cn-north-1" | |
| serviceName := "cv" | |
| credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName) | |
| stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", | |
| xDate, | |
| credentialScope, | |
| hexHashedCanonicalRequest, | |
| ) | |
| kDate := hmacSHA256([]byte(secretKey), []byte(shortDate)) | |
| kRegion := hmacSHA256(kDate, []byte(region)) | |
| kService := hmacSHA256(kRegion, []byte(serviceName)) | |
| kSigning := hmacSHA256(kService, []byte("request")) | |
| signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign))) | |
| authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", | |
| accessKey, | |
| credentialScope, | |
| signedHeaders, | |
| signature, | |
| ) | |
| req.Header.Set("Authorization", authorization) | |
| return nil | |
| } | |
| func hmacSHA256(key []byte, data []byte) []byte { | |
| h := hmac.New(sha256.New, key) | |
| h.Write(data) | |
| return h.Sum(nil) | |
| } | |
| func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { | |
| r := requestPayload{ | |
| ReqKey: req.Model, | |
| Prompt: req.Prompt, | |
| } | |
| switch req.Duration { | |
| case 10: | |
| r.Frames = 241 // 24*10+1 = 241 | |
| default: | |
| r.Frames = 121 // 24*5+1 = 121 | |
| } | |
| // Handle one-of image_urls or binary_data_base64 | |
| if req.HasImage() { | |
| if strings.HasPrefix(req.Images[0], "http") { | |
| r.ImageUrls = req.Images | |
| } else { | |
| r.BinaryDataBase64 = req.Images | |
| } | |
| } | |
| metadata := req.Metadata | |
| medaBytes, err := json.Marshal(metadata) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "metadata marshal metadata failed") | |
| } | |
| err = json.Unmarshal(medaBytes, &r) | |
| if err != nil { | |
| return nil, errors.Wrap(err, "unmarshal metadata failed") | |
| } | |
| // 即梦视频3.0 ReqKey转换 | |
| // https://www.volcengine.com/docs/85621/1792707 | |
| if strings.Contains(r.ReqKey, "jimeng_v30") { | |
| if len(r.ImageUrls) > 1 { | |
| // 多张图片:首尾帧生成 | |
| r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1) | |
| } else if len(r.ImageUrls) == 1 { | |
| // 单张图片:图生视频 | |
| r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1) | |
| } else { | |
| // 无图片:文生视频 | |
| r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1) | |
| } | |
| } | |
| return &r, nil | |
| } | |
| func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { | |
| resTask := responseTask{} | |
| if err := json.Unmarshal(respBody, &resTask); err != nil { | |
| return nil, errors.Wrap(err, "unmarshal task result failed") | |
| } | |
| taskResult := relaycommon.TaskInfo{} | |
| if resTask.Code == 10000 { | |
| taskResult.Code = 0 | |
| } else { | |
| taskResult.Code = resTask.Code // todo uni code | |
| taskResult.Reason = resTask.Message | |
| taskResult.Status = model.TaskStatusFailure | |
| taskResult.Progress = "100%" | |
| } | |
| switch resTask.Data.Status { | |
| case "in_queue": | |
| taskResult.Status = model.TaskStatusQueued | |
| taskResult.Progress = "10%" | |
| case "done": | |
| taskResult.Status = model.TaskStatusSuccess | |
| taskResult.Progress = "100%" | |
| } | |
| taskResult.Url = resTask.Data.VideoUrl | |
| return &taskResult, nil | |
| } | |
| func isNewAPIRelay(apiKey string) bool { | |
| return strings.HasPrefix(apiKey, "sk-") | |
| } | |