| package jimeng |
|
|
| import ( |
| "bytes" |
| "crypto/hmac" |
| "crypto/sha256" |
| "encoding/base64" |
| "encoding/hex" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net/http" |
| "net/url" |
| "sort" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/model" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/pkg/errors" |
|
|
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/relay/channel" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/service" |
| ) |
|
|
| |
| |
| |
|
|
| type requestPayload struct { |
| ReqKey string `json:"req_key"` |
| BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` |
| ImageUrls []string `json:"image_urls,omitempty"` |
| Prompt string `json:"prompt,omitempty"` |
| Seed int64 `json:"seed"` |
| AspectRatio string `json:"aspect_ratio"` |
| Frames int `json:"frames,omitempty"` |
| } |
|
|
| type responsePayload struct { |
| Code int `json:"code"` |
| Message string `json:"message"` |
| RequestId string `json:"request_id"` |
| Data struct { |
| TaskID string `json:"task_id"` |
| } `json:"data"` |
| } |
|
|
| type responseTask struct { |
| Code int `json:"code"` |
| Data struct { |
| BinaryDataBase64 []interface{} `json:"binary_data_base64"` |
| ImageUrls interface{} `json:"image_urls"` |
| RespData string `json:"resp_data"` |
| Status string `json:"status"` |
| VideoUrl string `json:"video_url"` |
| } `json:"data"` |
| Message string `json:"message"` |
| RequestId string `json:"request_id"` |
| Status int `json:"status"` |
| TimeElapsed string `json:"time_elapsed"` |
| } |
|
|
| const ( |
| |
| MaxFileSize int64 = 4*1024*1024 + 700*1024 |
| ) |
|
|
| |
| |
| |
|
|
| type TaskAdaptor struct { |
| ChannelType int |
| accessKey string |
| secretKey string |
| baseURL string |
| } |
|
|
| func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { |
| a.ChannelType = info.ChannelType |
| a.baseURL = info.ChannelBaseUrl |
|
|
| |
| keyParts := strings.Split(info.ApiKey, "|") |
| if len(keyParts) == 2 { |
| a.accessKey = strings.TrimSpace(keyParts[0]) |
| a.secretKey = strings.TrimSpace(keyParts[1]) |
| } |
| } |
|
|
| |
| func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { |
| return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) |
| } |
|
|
| |
| func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { |
| if isNewAPIRelay(info.ApiKey) { |
| return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil |
| } |
| return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil |
| } |
|
|
| |
| func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { |
| req.Header.Set("Content-Type", "application/json") |
| req.Header.Set("Accept", "application/json") |
| if isNewAPIRelay(info.ApiKey) { |
| req.Header.Set("Authorization", "Bearer "+info.ApiKey) |
| } else { |
| return a.signRequest(req, a.accessKey, a.secretKey) |
| } |
| return nil |
| } |
|
|
| func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { |
| v, exists := c.Get("task_request") |
| if !exists { |
| return nil, fmt.Errorf("request not found in context") |
| } |
| req, ok := v.(relaycommon.TaskSubmitReq) |
| if !ok { |
| return nil, fmt.Errorf("invalid request type in context") |
| } |
| |
| if mf, err := c.MultipartForm(); err == nil { |
| if files, exists := mf.File["input_reference"]; exists && len(files) > 0 { |
| if len(files) == 1 { |
| info.Action = constant.TaskActionGenerate |
| } else if len(files) > 1 { |
| info.Action = constant.TaskActionFirstTailGenerate |
| } |
|
|
| |
| var images []string |
|
|
| for _, fileHeader := range files { |
| |
| if fileHeader.Size > MaxFileSize { |
| return nil, fmt.Errorf("文件 %s 大小超过限制,最大允许 %d MB", fileHeader.Filename, MaxFileSize/(1024*1024)) |
| } |
|
|
| file, err := fileHeader.Open() |
| if err != nil { |
| continue |
| } |
| fileBytes, err := io.ReadAll(file) |
| file.Close() |
| if err != nil { |
| continue |
| } |
| |
| base64Str := base64.StdEncoding.EncodeToString(fileBytes) |
| images = append(images, base64Str) |
| } |
| req.Images = images |
| } |
| } |
|
|
| body, err := a.convertToRequestPayload(&req) |
| if err != nil { |
| return nil, errors.Wrap(err, "convert request payload failed") |
| } |
| data, err := json.Marshal(body) |
| if err != nil { |
| return nil, err |
| } |
| return bytes.NewReader(data), nil |
| } |
|
|
| |
| func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
| return channel.DoTaskApiRequest(a, c, info, requestBody) |
| } |
|
|
| |
| func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { |
| responseBody, err := io.ReadAll(resp.Body) |
| if err != nil { |
| taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) |
| return |
| } |
| _ = resp.Body.Close() |
|
|
| |
| var jResp responsePayload |
| if err := json.Unmarshal(responseBody, &jResp); err != nil { |
| taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) |
| return |
| } |
|
|
| if jResp.Code != 10000 { |
| taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) |
| return |
| } |
|
|
| ov := dto.NewOpenAIVideo() |
| ov.ID = jResp.Data.TaskID |
| ov.TaskID = jResp.Data.TaskID |
| ov.CreatedAt = time.Now().Unix() |
| ov.Model = info.OriginModelName |
| c.JSON(http.StatusOK, ov) |
| return jResp.Data.TaskID, responseBody, nil |
| } |
|
|
| |
| func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { |
| taskID, ok := body["task_id"].(string) |
| if !ok { |
| return nil, fmt.Errorf("invalid task_id") |
| } |
|
|
| uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) |
| if isNewAPIRelay(key) { |
| uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL) |
| } |
| payload := map[string]string{ |
| "req_key": "jimeng_vgfm_t2v_l20", |
| "task_id": taskID, |
| } |
| payloadBytes, err := json.Marshal(payload) |
| if err != nil { |
| return nil, errors.Wrap(err, "marshal fetch task payload failed") |
| } |
|
|
| req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes)) |
| if err != nil { |
| return nil, err |
| } |
|
|
| req.Header.Set("Accept", "application/json") |
| req.Header.Set("Content-Type", "application/json") |
|
|
| if isNewAPIRelay(key) { |
| req.Header.Set("Authorization", "Bearer "+key) |
| } else { |
| keyParts := strings.Split(key, "|") |
| if len(keyParts) != 2 { |
| return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") |
| } |
| accessKey := strings.TrimSpace(keyParts[0]) |
| secretKey := strings.TrimSpace(keyParts[1]) |
|
|
| if err := a.signRequest(req, accessKey, secretKey); err != nil { |
| return nil, errors.Wrap(err, "sign request failed") |
| } |
| } |
| return service.GetHttpClient().Do(req) |
| } |
|
|
| func (a *TaskAdaptor) GetModelList() []string { |
| return []string{"jimeng_vgfm_t2v_l20"} |
| } |
|
|
| func (a *TaskAdaptor) GetChannelName() string { |
| return "jimeng" |
| } |
|
|
| func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error { |
| var bodyBytes []byte |
| var err error |
|
|
| if req.Body != nil { |
| bodyBytes, err = io.ReadAll(req.Body) |
| if err != nil { |
| return errors.Wrap(err, "read request body failed") |
| } |
| _ = req.Body.Close() |
| req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) |
| } else { |
| bodyBytes = []byte{} |
| } |
|
|
| payloadHash := sha256.Sum256(bodyBytes) |
| hexPayloadHash := hex.EncodeToString(payloadHash[:]) |
|
|
| t := time.Now().UTC() |
| xDate := t.Format("20060102T150405Z") |
| shortDate := t.Format("20060102") |
|
|
| req.Header.Set("Host", req.URL.Host) |
| req.Header.Set("X-Date", xDate) |
| req.Header.Set("X-Content-Sha256", hexPayloadHash) |
|
|
| |
| queryParams := req.URL.Query() |
| sortedKeys := make([]string, 0, len(queryParams)) |
| for k := range queryParams { |
| sortedKeys = append(sortedKeys, k) |
| } |
| sort.Strings(sortedKeys) |
| var queryParts []string |
| for _, k := range sortedKeys { |
| values := queryParams[k] |
| sort.Strings(values) |
| for _, v := range values { |
| queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v))) |
| } |
| } |
| canonicalQueryString := strings.Join(queryParts, "&") |
|
|
| headersToSign := map[string]string{ |
| "host": req.URL.Host, |
| "x-date": xDate, |
| "x-content-sha256": hexPayloadHash, |
| } |
| if req.Header.Get("Content-Type") != "" { |
| headersToSign["content-type"] = req.Header.Get("Content-Type") |
| } |
|
|
| var signedHeaderKeys []string |
| for k := range headersToSign { |
| signedHeaderKeys = append(signedHeaderKeys, k) |
| } |
| sort.Strings(signedHeaderKeys) |
|
|
| var canonicalHeaders strings.Builder |
| for _, k := range signedHeaderKeys { |
| canonicalHeaders.WriteString(k) |
| canonicalHeaders.WriteString(":") |
| canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k])) |
| canonicalHeaders.WriteString("\n") |
| } |
| signedHeaders := strings.Join(signedHeaderKeys, ";") |
|
|
| canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", |
| req.Method, |
| req.URL.Path, |
| canonicalQueryString, |
| canonicalHeaders.String(), |
| signedHeaders, |
| hexPayloadHash, |
| ) |
|
|
| hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest)) |
| hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:]) |
|
|
| region := "cn-north-1" |
| serviceName := "cv" |
| credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName) |
| stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", |
| xDate, |
| credentialScope, |
| hexHashedCanonicalRequest, |
| ) |
|
|
| kDate := hmacSHA256([]byte(secretKey), []byte(shortDate)) |
| kRegion := hmacSHA256(kDate, []byte(region)) |
| kService := hmacSHA256(kRegion, []byte(serviceName)) |
| kSigning := hmacSHA256(kService, []byte("request")) |
| signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign))) |
|
|
| authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", |
| accessKey, |
| credentialScope, |
| signedHeaders, |
| signature, |
| ) |
| req.Header.Set("Authorization", authorization) |
| return nil |
| } |
|
|
| func hmacSHA256(key []byte, data []byte) []byte { |
| h := hmac.New(sha256.New, key) |
| h.Write(data) |
| return h.Sum(nil) |
| } |
|
|
| func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { |
| r := requestPayload{ |
| ReqKey: req.Model, |
| Prompt: req.Prompt, |
| } |
|
|
| switch req.Duration { |
| case 10: |
| r.Frames = 241 |
| default: |
| r.Frames = 121 |
| } |
|
|
| |
| if req.HasImage() { |
| if strings.HasPrefix(req.Images[0], "http") { |
| r.ImageUrls = req.Images |
| } else { |
| r.BinaryDataBase64 = req.Images |
| } |
| } |
| metadata := req.Metadata |
| medaBytes, err := json.Marshal(metadata) |
| if err != nil { |
| return nil, errors.Wrap(err, "metadata marshal metadata failed") |
| } |
| err = json.Unmarshal(medaBytes, &r) |
| if err != nil { |
| return nil, errors.Wrap(err, "unmarshal metadata failed") |
| } |
|
|
| |
| |
| if strings.Contains(r.ReqKey, "jimeng_v30") { |
| if r.ReqKey == "jimeng_v30_pro" { |
| |
| r.ReqKey = "jimeng_ti2v_v30_pro" |
| } else if len(req.Images) > 1 { |
| |
| r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p") |
| } else if len(req.Images) == 1 { |
| |
| r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p") |
| } else { |
| |
| r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1) |
| } |
| } |
|
|
| return &r, nil |
| } |
|
|
| func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { |
| resTask := responseTask{} |
| if err := json.Unmarshal(respBody, &resTask); err != nil { |
| return nil, errors.Wrap(err, "unmarshal task result failed") |
| } |
| taskResult := relaycommon.TaskInfo{} |
| if resTask.Code == 10000 { |
| taskResult.Code = 0 |
| } else { |
| taskResult.Code = resTask.Code |
| taskResult.Reason = resTask.Message |
| taskResult.Status = model.TaskStatusFailure |
| taskResult.Progress = "100%" |
| } |
| switch resTask.Data.Status { |
| case "in_queue": |
| taskResult.Status = model.TaskStatusQueued |
| taskResult.Progress = "10%" |
| case "done": |
| taskResult.Status = model.TaskStatusSuccess |
| taskResult.Progress = "100%" |
| } |
| taskResult.Url = resTask.Data.VideoUrl |
| return &taskResult, nil |
| } |
|
|
| func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { |
| var jimengResp responseTask |
| if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil { |
| return nil, errors.Wrap(err, "unmarshal jimeng task data failed") |
| } |
|
|
| openAIVideo := dto.NewOpenAIVideo() |
| openAIVideo.ID = originTask.TaskID |
| openAIVideo.Status = originTask.Status.ToVideoStatus() |
| openAIVideo.SetProgressStr(originTask.Progress) |
| openAIVideo.SetMetadata("url", jimengResp.Data.VideoUrl) |
| openAIVideo.CreatedAt = originTask.CreatedAt |
| openAIVideo.CompletedAt = originTask.UpdatedAt |
|
|
| if jimengResp.Code != 10000 { |
| openAIVideo.Error = &dto.OpenAIVideoError{ |
| Message: jimengResp.Message, |
| Code: fmt.Sprintf("%d", jimengResp.Code), |
| } |
| } |
|
|
| jsonData, _ := common.Marshal(openAIVideo) |
| return jsonData, nil |
| } |
|
|
| func isNewAPIRelay(apiKey string) bool { |
| return strings.HasPrefix(apiKey, "sk-") |
| } |
|
|