Spaces:
Build error
Build error
| package controller | |
| import ( | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "one-api/common" | |
| "one-api/constant" | |
| "one-api/dto" | |
| "one-api/logger" | |
| "one-api/model" | |
| "one-api/relay" | |
| "one-api/relay/channel" | |
| relaycommon "one-api/relay/common" | |
| "time" | |
| ) | |
| func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { | |
| for channelId, taskIds := range taskChannelM { | |
| if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { | |
| logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) | |
| } | |
| } | |
| return nil | |
| } | |
| func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { | |
| logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) | |
| if len(taskIds) == 0 { | |
| return nil | |
| } | |
| cacheGetChannel, err := model.CacheGetChannel(channelId) | |
| if err != nil { | |
| errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ | |
| "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), | |
| "status": "FAILURE", | |
| "progress": "100%", | |
| }) | |
| if errUpdate != nil { | |
| common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) | |
| } | |
| return fmt.Errorf("CacheGetChannel failed: %w", err) | |
| } | |
| adaptor := relay.GetTaskAdaptor(platform) | |
| if adaptor == nil { | |
| return fmt.Errorf("video adaptor not found") | |
| } | |
| for _, taskId := range taskIds { | |
| if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { | |
| logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) | |
| } | |
| } | |
| return nil | |
| } | |
| func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { | |
| baseURL := constant.ChannelBaseURLs[channel.Type] | |
| if channel.GetBaseURL() != "" { | |
| baseURL = channel.GetBaseURL() | |
| } | |
| task := taskM[taskId] | |
| if task == nil { | |
| logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) | |
| return fmt.Errorf("task %s not found", taskId) | |
| } | |
| resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ | |
| "task_id": taskId, | |
| "action": task.Action, | |
| }) | |
| if err != nil { | |
| return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) | |
| } | |
| //if resp.StatusCode != http.StatusOK { | |
| //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) | |
| //} | |
| defer resp.Body.Close() | |
| responseBody, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return fmt.Errorf("readAll failed for task %s: %w", taskId, err) | |
| } | |
| taskResult := &relaycommon.TaskInfo{} | |
| // try parse as New API response format | |
| var responseItems dto.TaskResponse[model.Task] | |
| if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { | |
| t := responseItems.Data | |
| taskResult.TaskID = t.TaskID | |
| taskResult.Status = string(t.Status) | |
| taskResult.Url = t.FailReason | |
| taskResult.Progress = t.Progress | |
| taskResult.Reason = t.FailReason | |
| } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { | |
| return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) | |
| } else { | |
| task.Data = redactVideoResponseBody(responseBody) | |
| } | |
| now := time.Now().Unix() | |
| if taskResult.Status == "" { | |
| return fmt.Errorf("task %s status is empty", taskId) | |
| } | |
| task.Status = model.TaskStatus(taskResult.Status) | |
| switch taskResult.Status { | |
| case model.TaskStatusSubmitted: | |
| task.Progress = "10%" | |
| case model.TaskStatusQueued: | |
| task.Progress = "20%" | |
| case model.TaskStatusInProgress: | |
| task.Progress = "30%" | |
| if task.StartTime == 0 { | |
| task.StartTime = now | |
| } | |
| case model.TaskStatusSuccess: | |
| task.Progress = "100%" | |
| if task.FinishTime == 0 { | |
| task.FinishTime = now | |
| } | |
| if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { | |
| task.FailReason = taskResult.Url | |
| } | |
| case model.TaskStatusFailure: | |
| task.Status = model.TaskStatusFailure | |
| task.Progress = "100%" | |
| if task.FinishTime == 0 { | |
| task.FinishTime = now | |
| } | |
| task.FailReason = taskResult.Reason | |
| logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) | |
| quota := task.Quota | |
| if quota != 0 { | |
| if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { | |
| logger.LogError(ctx, "Failed to increase user quota: "+err.Error()) | |
| } | |
| logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) | |
| model.RecordLog(task.UserId, model.LogTypeSystem, logContent) | |
| } | |
| default: | |
| return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) | |
| } | |
| if taskResult.Progress != "" { | |
| task.Progress = taskResult.Progress | |
| } | |
| if err := task.Update(); err != nil { | |
| common.SysLog("UpdateVideoTask task error: " + err.Error()) | |
| } | |
| return nil | |
| } | |
// redactVideoResponseBody returns a version of an upstream video-task response
// body that is suitable for storing in task.Data. Inside the top-level
// "response" object, inline video payloads are handled as follows:
//   - "bytesBase64Encoded" is removed from the object itself;
//   - "bytesBase64Encoded" is removed from each object element of "videos";
//   - a string "video" is shortened with truncateBase64.
//
// Only the top level and the "response" object are decoded. Every other value
// is carried through as raw JSON, so numbers are never rounded through float64
// (integers above 2^53, such as 64-bit IDs, stay exact). A body that is not a
// JSON object, or that has no "response" object, is returned unchanged. When
// the body is rewritten, the result is compact JSON with object keys sorted,
// as produced by encoding/json.
func redactVideoResponseBody(body []byte) []byte {
	var top map[string]json.RawMessage
	if err := json.Unmarshal(body, &top); err != nil {
		return body
	}
	rawResp, ok := top["response"]
	if !ok {
		return body
	}
	var resp map[string]json.RawMessage
	if err := json.Unmarshal(rawResp, &resp); err != nil || resp == nil {
		// "response" is null or not an object: nothing to redact.
		return body
	}

	delete(resp, "bytesBase64Encoded")

	if raw, ok := resp["video"]; ok {
		var v any
		if json.Unmarshal(raw, &v) == nil {
			if s, isStr := v.(string); isStr {
				if enc, err := json.Marshal(truncateBase64(s)); err == nil {
					resp["video"] = enc
				}
			}
		}
	}

	if raw, ok := resp["videos"]; ok {
		var items []json.RawMessage
		if json.Unmarshal(raw, &items) == nil && items != nil {
			for i, item := range items {
				var vm map[string]json.RawMessage
				if json.Unmarshal(item, &vm) != nil || vm == nil {
					continue // not an object: leave the element as-is
				}
				delete(vm, "bytesBase64Encoded")
				if enc, err := json.Marshal(vm); err == nil {
					items[i] = enc
				}
			}
			if enc, err := json.Marshal(items); err == nil {
				resp["videos"] = enc
			}
		}
	}

	encResp, err := json.Marshal(resp)
	if err != nil {
		return body
	}
	top["response"] = encResp
	out, err := json.Marshal(top)
	if err != nil {
		return body
	}
	return out
}

// truncateBase64 returns s unchanged if it is at most 256 bytes long.
// Otherwise it returns the first 256 bytes followed by "...".
// The cut is byte-based, which is safe for base64 (pure ASCII) payloads.
func truncateBase64(s string) string {
	const maxKeep = 256
	if len(s) <= maxKeep {
		return s
	}
	return s[:maxKeep] + "..."
}