| package doubao |
|
|
| import ( |
| "bytes" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net/http" |
|
|
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay/channel" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/service" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/pkg/errors" |
| ) |
|
|
| |
| |
| |
|
|
// Wire types for the content-generation task endpoint
// (/api/v3/contents/generations/tasks): the submission body, the creation
// response, and the task-status response.
type (
	// ContentItem is one entry of a submission's content list. The adaptor
	// emits either a text prompt (Type "text", Text set) or an image
	// reference (Type "image_url", ImageURL set).
	ContentItem struct {
		Type     string    `json:"type"`
		Text     string    `json:"text,omitempty"`
		ImageURL *ImageURL `json:"image_url,omitempty"`
	}

	// ImageURL wraps the URL of an input image.
	ImageURL struct {
		URL string `json:"url"`
	}

	// requestPayload is the JSON body POSTed to create a generation task.
	requestPayload struct {
		Model   string        `json:"model"`
		Content []ContentItem `json:"content"`
	}

	// responsePayload is the part of the task-creation response that is
	// read: only the upstream task ID.
	responsePayload struct {
		ID string `json:"id"`
	}

	// responseTask is the decoded body of a task-status query.
	responseTask struct {
		ID     string `json:"id"`
		Model  string `json:"model"`
		Status string `json:"status"`
		// Content carries the generated artifact once the task has finished.
		Content struct {
			VideoURL string `json:"video_url"`
		} `json:"content"`
		Seed            int    `json:"seed"`
		Resolution      string `json:"resolution"`
		Duration        int    `json:"duration"`
		Ratio           string `json:"ratio"`
		FramesPerSecond int    `json:"framespersecond"`
		// Usage reports token accounting for the task.
		Usage struct {
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
		CreatedAt int64 `json:"created_at"`
		UpdatedAt int64 `json:"updated_at"`
	}
)
|
|
| |
| |
| |
|
|
// TaskAdaptor relays asynchronous generation tasks to the Doubao task API.
// Its fields are filled in by Init from the per-request relay info.
type TaskAdaptor struct {
	// ChannelType is copied from the relay info.
	ChannelType int

	// apiKey is sent as the bearer token in the Authorization header.
	apiKey string
	// baseURL is the upstream base URL; endpoint paths are appended to it.
	baseURL string
}
|
|
| func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { |
| a.ChannelType = info.ChannelType |
| a.baseURL = info.ChannelBaseUrl |
| a.apiKey = info.ApiKey |
| } |
|
|
| |
| func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { |
| |
| return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) |
| } |
|
|
| |
| func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { |
| return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil |
| } |
|
|
| |
| func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { |
| req.Header.Set("Content-Type", "application/json") |
| req.Header.Set("Accept", "application/json") |
| req.Header.Set("Authorization", "Bearer "+a.apiKey) |
| return nil |
| } |
|
|
| |
| func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { |
| v, exists := c.Get("task_request") |
| if !exists { |
| return nil, fmt.Errorf("request not found in context") |
| } |
| req := v.(relaycommon.TaskSubmitReq) |
|
|
| body, err := a.convertToRequestPayload(&req) |
| if err != nil { |
| return nil, errors.Wrap(err, "convert request payload failed") |
| } |
| data, err := json.Marshal(body) |
| if err != nil { |
| return nil, err |
| } |
| return bytes.NewReader(data), nil |
| } |
|
|
| |
| func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
| return channel.DoTaskApiRequest(a, c, info, requestBody) |
| } |
|
|
| |
| func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { |
| responseBody, err := io.ReadAll(resp.Body) |
| if err != nil { |
| taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) |
| return |
| } |
| _ = resp.Body.Close() |
|
|
| |
| var dResp responsePayload |
| if err := json.Unmarshal(responseBody, &dResp); err != nil { |
| taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) |
| return |
| } |
|
|
| if dResp.ID == "" { |
| taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID}) |
| return dResp.ID, responseBody, nil |
| } |
|
|
| |
| func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { |
| taskID, ok := body["task_id"].(string) |
| if !ok { |
| return nil, fmt.Errorf("invalid task_id") |
| } |
|
|
| uri := fmt.Sprintf("%s/api/v3/contents/generations/tasks/%s", baseUrl, taskID) |
|
|
| req, err := http.NewRequest(http.MethodGet, uri, nil) |
| if err != nil { |
| return nil, err |
| } |
|
|
| req.Header.Set("Accept", "application/json") |
| req.Header.Set("Content-Type", "application/json") |
| req.Header.Set("Authorization", "Bearer "+key) |
|
|
| return service.GetHttpClient().Do(req) |
| } |
|
|
| func (a *TaskAdaptor) GetModelList() []string { |
| return ModelList |
| } |
|
|
| func (a *TaskAdaptor) GetChannelName() string { |
| return ChannelName |
| } |
|
|
| func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { |
| r := requestPayload{ |
| Model: req.Model, |
| Content: []ContentItem{}, |
| } |
|
|
| |
| if req.Prompt != "" { |
| r.Content = append(r.Content, ContentItem{ |
| Type: "text", |
| Text: req.Prompt, |
| }) |
| } |
|
|
| |
| if req.HasImage() { |
| for _, imgURL := range req.Images { |
| r.Content = append(r.Content, ContentItem{ |
| Type: "image_url", |
| ImageURL: &ImageURL{ |
| URL: imgURL, |
| }, |
| }) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
|
|
| return &r, nil |
| } |
|
|
| func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { |
| resTask := responseTask{} |
| if err := json.Unmarshal(respBody, &resTask); err != nil { |
| return nil, errors.Wrap(err, "unmarshal task result failed") |
| } |
|
|
| taskResult := relaycommon.TaskInfo{ |
| Code: 0, |
| } |
|
|
| |
| switch resTask.Status { |
| case "pending", "queued": |
| taskResult.Status = model.TaskStatusQueued |
| taskResult.Progress = "10%" |
| case "processing": |
| taskResult.Status = model.TaskStatusInProgress |
| taskResult.Progress = "50%" |
| case "succeeded": |
| taskResult.Status = model.TaskStatusSuccess |
| taskResult.Progress = "100%" |
| taskResult.Url = resTask.Content.VideoURL |
| |
| taskResult.CompletionTokens = resTask.Usage.CompletionTokens |
| taskResult.TotalTokens = resTask.Usage.TotalTokens |
| case "failed": |
| taskResult.Status = model.TaskStatusFailure |
| taskResult.Progress = "100%" |
| taskResult.Reason = "task failed" |
| default: |
| |
| taskResult.Status = model.TaskStatusInProgress |
| taskResult.Progress = "30%" |
| } |
|
|
| return &taskResult, nil |
| } |
|
|