package websocket import ( "encoding/json" "fmt" "io" "mime" "mime/multipart" "net/http" "net/url" "os" "path/filepath" "reflect" "strings" "sync" "time" "github.com/Tencent/AI-Infra-Guard/common/agent" "github.com/Tencent/AI-Infra-Guard/pkg/database" "github.com/gin-gonic/gin" "gorm.io/datatypes" "trpc.group/trpc-go/trpc-go/log" ) // 任务管理器相关数据结构 const ( WSMsgTypeTaskAssign = "task_assign" // 任务分配 // 任务状态常量 TaskStatusTodo = "todo" // 待执行 TaskStatusDoing = "doing" // 执行中 TaskStatusDone = "done" // 已完成 TaskStatusError = "error" TaskStatusTerminated = "terminated" // 已终止 ) type TaskManager struct { mu sync.RWMutex tasks map[string]*TaskCreateRequest // sessionId -> 任务请求 agentManager *AgentManager // 新增:引用 AgentManager taskStore *database.TaskStore // 新增:引用 TaskStore modelStore *database.ModelStore // 新增:引用 ModelStore fileConfig *FileUploadConfig // 新增:文件上传配置 sseManager *SSEManager // 新增:SSE管理器 } func NewTaskManager(agentManager *AgentManager, taskStore *database.TaskStore, modelStore *database.ModelStore, fileConfig *FileUploadConfig, sseManager *SSEManager) *TaskManager { if fileConfig == nil { fileConfig = DefaultFileUploadConfig() } if sseManager == nil { sseManager = NewSSEManager() } return &TaskManager{ tasks: make(map[string]*TaskCreateRequest), agentManager: agentManager, // 注入 AgentManager taskStore: taskStore, // 注入 TaskStore modelStore: modelStore, // 注入 ModelStore fileConfig: fileConfig, // 注入文件上传配置 sseManager: sseManager, // 注入SSE管理器 } } // 添加任务 func (tm *TaskManager) AddTask(req *TaskCreateRequest, traceID string) error { log.Infof("开始添加任务: trace_id=%s, sessionId=%s, taskType=%s, username=%s", traceID, req.SessionID, req.Task, req.Username) // 监控相关代码已移除 // 1. 先检查数据库中是否已存在相同的sessionId existingSession, err := tm.taskStore.GetSession(req.SessionID) if err == nil && existingSession != nil { log.Errorf("任务已存在: trace_id=%s, sessionId=%s, username=%s", traceID, req.SessionID, req.Username) return fmt.Errorf("任务已存在,sessionId: %s", req.SessionID) } // 2. 预存任务到数据库(状态为todo,assigned_agent为空) session := &database.Session{ ID: req.SessionID, Username: req.Username, Title: tm.generateTaskTitle(req), TaskType: req.Task, Content: req.Content, Params: mustMarshalJSON(req.Params), Attachments: mustMarshalJSON(req.Attachments), Status: TaskStatusDoing, AssignedAgent: "", // 预存时为空 CountryIsoCode: req.CountryIsoCode, Share: true, } err = tm.taskStore.CreateSession(session) if err != nil { log.Errorf("预存任务到数据库失败: trace_id=%s, sessionId=%s, error=%v", traceID, req.SessionID, err) return fmt.Errorf("预存任务失败: %v", err) } log.Infof("任务预存成功: trace_id=%s, sessionId=%s", traceID, req.SessionID) // 3. 等待SSE连接建立 timeout := 100 * time.Second start := time.Now() for time.Since(start) < timeout { if tm.sseManager.HasConnection(req.SessionID) { break // 连接已建立 } time.Sleep(500 * time.Millisecond) // 每50ms检查一次 } if !tm.sseManager.HasConnection(req.SessionID) { // SSE连接超时,清理预存的任务 tm.cleanupFailedTask(req.SessionID, traceID) log.Errorf("SSE连接建立超时: trace_id=%s, sessionId=%s, username=%s, timeout=%v", traceID, req.SessionID, req.Username, timeout) return fmt.Errorf("SSE连接建立超时,请重试,sessionId: %s", req.SessionID) } // 4. 存储任务到内存(dispatchTask需要从内存中获取任务) tm.mu.Lock() tm.tasks[req.SessionID] = req tm.mu.Unlock() // 5. 尝试分发任务 err = tm.dispatchTask(req.SessionID, traceID) if err != nil { // 分发失败,清理内存和数据库中的预存内容 tm.cleanupFailedTask(req.SessionID, traceID) log.Errorf("任务分发失败: trace_id=%s, sessionId=%s, error=%v", traceID, req.SessionID, err) return fmt.Errorf("任务分发失败: %v", err) } log.Infof("任务添加成功: trace_id=%s, sessionId=%s, taskType=%s", traceID, req.SessionID, req.Task) return nil } // 一键添加任务并执行 func (tm *TaskManager) AddTaskApi(req *TaskCreateRequest) error { // 1. 先检查数据库中是否已存在相同的sessionId existingSession, err := tm.taskStore.GetSession(req.SessionID) if err == nil && existingSession != nil { return fmt.Errorf("任务已存在,sessionId: %s", req.SessionID) } // 2. 预存任务到数据库(状态为todo,assigned_agent为空) session := &database.Session{ ID: req.SessionID, Username: req.Username, Title: tm.generateTaskTitle(req), TaskType: req.Task, Content: req.Content, Params: mustMarshalJSON(req.Params), Attachments: mustMarshalJSON(req.Attachments), Status: TaskStatusTodo, AssignedAgent: "", // 预存时为空 CountryIsoCode: req.CountryIsoCode, Share: true, } err = tm.taskStore.CreateSession(session) if err != nil { return fmt.Errorf("预存任务失败: %v", err) } // 获取可用 Agent(简化:不做额外健康检查) availableAgents := tm.agentManager.GetAvailableAgents() if len(availableAgents) == 0 { return fmt.Errorf("没有可用的Agent") } // 3. 选择 Agent(简单策略:选择第一个,相信GetAvailableAgents的过滤结果) selectedAgent := availableAgents[0] // 4. 更新session的assigned_agent和开始时间 err = tm.taskStore.UpdateSessionAssignedAgent(req.SessionID, selectedAgent.agentID) if err != nil { return fmt.Errorf("无法更新session的assigned_agent") } // 6. 构造任务分配消息 taskMsg := WSMessage{ Type: WSMsgTypeTaskAssign, Content: TaskContent{ SessionID: req.SessionID, TaskType: req.Task, Content: req.Content, Params: req.Params, Attachments: req.Attachments, Timeout: 3600, CountryIsoCode: req.CountryIsoCode, }, } // 7. 直接发送给 Agent(简化:无重试,无额外健康检查) selectedAgent.stateMu.RLock() agentID := selectedAgent.agentID selectedAgent.stateMu.RUnlock() // 设置写超时并直接发送 selectedAgent.conn.SetWriteDeadline(time.Now().Add(writeWait)) err = selectedAgent.conn.WriteJSON(taskMsg) if err != nil { return fmt.Errorf("下发任务给 %s 失败: %v", agentID, err) } log.Infof("任务分发成功: sessionId=%s, agentId=%s", req.SessionID, agentID) return nil } // cleanupFailedTask 清理失败的任务(内存和数据库) func (tm *TaskManager) cleanupFailedTask(sessionId string, traceID string) { log.Infof("开始清理失败任务: trace_id=%s, sessionId=%s", traceID, sessionId) // 清理内存中的任务 tm.mu.Lock() delete(tm.tasks, sessionId) tm.mu.Unlock() // 清理数据库中的预存任务 err := tm.taskStore.DeleteSession(sessionId) if err != nil { log.Errorf("清理数据库中的失败任务失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) } else { log.Infof("失败任务清理完成: trace_id=%s, sessionId=%s", traceID, sessionId) } } // 获取任务 func (tm *TaskManager) GetTask(sessionId string) (*TaskCreateRequest, bool) { tm.mu.RLock() defer tm.mu.RUnlock() task, ok := tm.tasks[sessionId] return task, ok } // 新增:任务分发方法(简化版本,减少死锁风险) func (tm *TaskManager) dispatchTask(sessionId string, traceID string) error { log.Infof("开始分发任务: trace_id=%s, sessionId=%s", traceID, sessionId) // 1. 获取任务 task, exists := tm.GetTask(sessionId) if !exists { log.Errorf("任务不存在: trace_id=%s, sessionId=%s", traceID, sessionId) return fmt.Errorf("任务不存在") } // 2. 获取可用 Agent(简化:不做额外健康检查) availableAgents := tm.agentManager.GetAvailableAgents() if len(availableAgents) == 0 { log.Warnf("没有可用的Agent: trace_id=%s, sessionId=%s", traceID, sessionId) return fmt.Errorf("没有可用的Agent") } log.Infof("找到可用Agent数量: trace_id=%s, sessionId=%s, count=%d", traceID, sessionId, len(availableAgents)) // 3. 选择 Agent(简单策略:选择第一个,相信GetAvailableAgents的过滤结果) selectedAgent := availableAgents[0] // 4. 更新session的assigned_agent和开始时间 err := tm.taskStore.UpdateSessionAssignedAgent(task.SessionID, selectedAgent.agentID) if err != nil { log.Errorf("无法更新session的assigned_agent: trace_id=%s, sessionId=%s, agentId=%s, error=%v", traceID, task.SessionID, selectedAgent.agentID, err) return fmt.Errorf("无法更新session的assigned_agent") } // 5. 处理params中的modelid,获取模型信息 enhancedParams := make(map[string]interface{}) for k, v := range task.Params { enhancedParams[k] = v } addModel := func(modelId string) (*database.ModelParams, error) { model, err := tm.modelStore.GetModel(modelId) if err != nil { // 检查是否是记录不存在的错误 if err.Error() == "record not found" { log.Errorf("模型不存在: trace_id=%s, sessionId=%s, modelID=%s", traceID, sessionId, modelId) return nil, fmt.Errorf("模型ID '%s' 不存在,请检查模型配置", modelId) } log.Errorf("获取模型信息失败: trace_id=%s, sessionId=%s, modelID=%s, error=%v", traceID, sessionId, modelId, err) return nil, fmt.Errorf("获取模型信息失败: %v", err) } // 测试模型是否有效 //ai := models.NewOpenAI(model.Token, model.ModelName, model.BaseURL) //err = ai.Vaild(context.Background()) //if err != nil { // log.Errorf("模型无效: trace_id=%s, sessionId=%s, modelID=%s, error=%v", traceID, sessionId, modelId, err) // return nil, fmt.Errorf("模型无效: %v", err) //} p := database.ModelParams{ Model: model.ModelName, Token: model.Token, BaseUrl: model.BaseURL, Limit: model.Limit, } return &p, nil } if task.Params != nil { if modelID, exists := task.Params["model_id"]; exists { log.Infof("找到模型ID: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, modelID) switch v := modelID.(type) { case string: modelInfo, err := addModel(v) if err != nil { return err } enhancedParams["model"] = modelInfo case []interface{}: modelsList := make([]*database.ModelParams, 0) log.Infof("找到多个模型ID: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, v) for _, vv := range v { vv, ok := vv.(string) if !ok { log.Errorf("无效的模型ID类型: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, vv) continue } modelInfo, err := addModel(vv) if err != nil { return err } modelsList = append(modelsList, modelInfo) } enhancedParams["model"] = modelsList default: log.Errorf("无效的模型ID类型: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, v) } } if evalModelStr, exists := task.Params["eval_model_id"]; exists { evalModelId, ok := evalModelStr.(string) if ok { evalModelInfo, err := addModel(evalModelId) if err != nil { return err } enhancedParams["eval_model"] = evalModelInfo } } } // 6. 构造任务分配消息 taskMsg := WSMessage{ Type: WSMsgTypeTaskAssign, Content: TaskContent{ SessionID: task.SessionID, TaskType: task.Task, Content: task.Content, Params: enhancedParams, Attachments: task.Attachments, Timeout: 3600, CountryIsoCode: task.CountryIsoCode, }, } log.Infof("任务分配消息: trace_id=%s, sessionId=%s, taskMsg=%+v", traceID, sessionId, taskMsg) // 7. 直接发送给 Agent(简化:无重试,无额外健康检查) selectedAgent.stateMu.RLock() agentID := selectedAgent.agentID isActive := selectedAgent.isActive selectedAgent.stateMu.RUnlock() if !isActive { log.Errorf("选中的Agent已不活跃: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, agentID) // 重置assigned_agent tm.taskStore.UpdateSessionAssignedAgent(task.SessionID, "") return fmt.Errorf("选中的Agent已不活跃: %s", agentID) } // 设置写超时并直接发送 selectedAgent.conn.SetWriteDeadline(time.Now().Add(writeWait)) err = selectedAgent.conn.WriteJSON(taskMsg) if err != nil { log.Errorf("下发任务给Agent失败: trace_id=%s, sessionId=%s, agentId=%s, error=%v", traceID, task.SessionID, agentID, err) return fmt.Errorf("下发任务给 %s 失败: %v", agentID, err) } log.Infof("任务分发成功: trace_id=%s, sessionId=%s, agentId=%s", traceID, task.SessionID, agentID) return nil } // HandleAgentEvent 处理来自Agent的事件 func (tm *TaskManager) HandleAgentEvent(sessionId string, eventType string, event interface{}) { log.Debugf("收到Agent事件: sessionId=%s, eventType=%s", sessionId, eventType) // 使用通用事件处理函数 tm.handleEvent(sessionId, eventType, event) // 根据事件类型记录特定日志 switch eventType { case "liveStatus": if convertedEvent, err := convertToStruct(event, &LiveStatusEvent{}); err == nil { if liveStatusEvent, ok := convertedEvent.(*LiveStatusEvent); ok { log.Debugf("liveStatus事件详情: sessionId=%s, text=%s", sessionId, liveStatusEvent.Text) } } case "planUpdate": if convertedEvent, err := convertToStruct(event, &PlanUpdateEvent{}); err == nil { if planUpdateEvent, ok := convertedEvent.(*PlanUpdateEvent); ok { log.Infof("收到计划更新: sessionId=%s, tasks=%d", sessionId, len(planUpdateEvent.Tasks)) } } case "newPlanStep": if convertedEvent, err := convertToStruct(event, &NewPlanStepEvent{}); err == nil { if newPlanStepEvent, ok := convertedEvent.(*NewPlanStepEvent); ok { log.Infof("新计划步骤: sessionId=%s, stepId=%s", sessionId, newPlanStepEvent.StepID) } } case "statusUpdate": if convertedEvent, err := convertToStruct(event, &StatusUpdateEvent{}); err == nil { if statusUpdateEvent, ok := convertedEvent.(*StatusUpdateEvent); ok { log.Infof("状态更新: sessionId=%s, status=%s", sessionId, statusUpdateEvent.AgentStatus) } } case "toolUsed": if convertedEvent, err := convertToStruct(event, &ToolUsedEvent{}); err == nil { if toolUsedEvent, ok := convertedEvent.(*ToolUsedEvent); ok { log.Infof("工具使用: sessionId=%s, tools=%d", sessionId, len(toolUsedEvent.Tools)) } } case "actionLog": if convertedEvent, err := convertToStruct(event, &ActionLogEvent{}); err == nil { if actionLogEvent, ok := convertedEvent.(*ActionLogEvent); ok { log.Debugf("动作日志: sessionId=%s, actionId=%s", sessionId, actionLogEvent.ActionID) } } case "error": log.Errorf("错误事件: sessionId=%s %v", sessionId, event) updates := map[string]interface{}{ "status": "error", } err := tm.taskStore.UpdateSession(sessionId, updates) if err != nil { log.Errorf("更新任务失败: sessionId=%s, error=%v", sessionId, err) } case "resultUpdate": if convertedEvent, err := convertToStruct(event, &ResultUpdateEvent{}); err == nil { if _, ok := convertedEvent.(*ResultUpdateEvent); ok { log.Infof("任务完成: sessionId=%s", sessionId) // 监控相关代码已移除 // 更新任务状态为已完成 err := tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusDone) if err != nil { log.Errorf("更新任务状态为已完成失败: sessionId=%s, error=%v", sessionId, err) } else { log.Infof("任务状态已更新为已完成: sessionId=%s", sessionId) } // 任务完成,可以清理资源 go tm.cleanupTask(sessionId) } } default: log.Debugf("未知事件类型: sessionId=%s, eventType=%s", sessionId, eventType) } } // convertToStruct 将 interface{} 转换为指定的结构体类型 func convertToStruct(data interface{}, target interface{}) (interface{}, error) { // 先序列化为JSON jsonData, err := json.Marshal(data) if err != nil { return nil, err } // 再反序列化为目标结构体 err = json.Unmarshal(jsonData, target) if err != nil { return nil, err } return target, nil } // generateSecureFileName 生成安全的唯一文件名 func generateSecureFileName(originalName string) string { // 获取文件扩展名 ext := filepath.Ext(originalName) // 获取不带扩展名的原始文件名 baseName := strings.TrimSuffix(originalName, ext) // 生成UUID uuid := generateUUID() // 组合:UUID_原始文件名.扩展名 return fmt.Sprintf("%s_%s%s", baseName, uuid, ext) } // generateUUID 生成简单的UUID func generateUUID() string { return fmt.Sprintf("%d_%d", time.Now().UnixNano(), time.Now().Unix()) } // 通用事件处理函数 func (tm *TaskManager) handleEvent(sessionId string, eventType string, event interface{}) { log.Debugf("开始处理事件: sessionId=%s, eventType=%s", sessionId, eventType) // 生成事件ID id := generateEventID() // 获取事件的时间戳 timestamp := getEventTimestamp(event) // 存储事件到数据库 err := tm.taskStore.StoreEvent(id, sessionId, eventType, event, timestamp) if err != nil { log.Errorf("存储事件失败: sessionId=%s, eventType=%s, error=%v", sessionId, eventType, err) return } // 推送事件到SSE err = tm.sseManager.SendEvent(id, sessionId, eventType, event) if err != nil { // 如果是连接不存在的错误,记录为调试信息而不是错误 if strings.Contains(err.Error(), "连接不存在") { log.Debugf("SSE连接已关闭,跳过事件推送: sessionId=%s, eventType=%s", sessionId, eventType) } else { log.Errorf("推送事件到SSE失败: sessionId=%s, eventType=%s, error=%v", sessionId, eventType, err) } return } // 记录日志 log.Debugf("事件处理完成: sessionId=%s, eventType=%s", sessionId, eventType) } // getEventTimestamp 获取事件的时间戳 func getEventTimestamp(event interface{}) int64 { // 使用反射获取Timestamp字段 v := reflect.ValueOf(event) if v.Kind() == reflect.Ptr { v = v.Elem() } if v.Kind() == reflect.Struct { if field := v.FieldByName("Timestamp"); field.IsValid() && field.CanInterface() { if timestamp, ok := field.Interface().(int64); ok { return timestamp } } } // 如果无法获取时间戳,使用当前时间 return time.Now().UnixMilli() } // TerminateTask 终止任务 func (tm *TaskManager) TerminateTask(sessionId string, username string, traceID string) error { log.Infof("开始终止任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) // 检查任务是否存在 session, err := tm.taskStore.GetSession(sessionId) if err != nil { log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) return fmt.Errorf("任务不存在") } // 验证用户权限(只有任务创建者才能终止任务) if session.Username != username { log.Errorf("无权限终止任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) return fmt.Errorf("无权限操作此任务") } // 通知 Agent 终止任务 if session.AssignedAgent != "" { log.Infof("通知Agent终止任务: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, session.AssignedAgent) tm.notifyAgentToTerminate(session.AssignedAgent, sessionId, traceID) } // 发送终止事件给前端 tm.sendTerminationEvent(sessionId, traceID) // 更新任务状态为已终止 err = tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusTerminated) if err != nil { log.Errorf("更新任务状态失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) return fmt.Errorf("更新任务状态失败") } log.Infof("任务终止完成: trace_id=%s, sessionId=%s", traceID, sessionId) // 监控相关代码已移除 // 异步清理任务资源 go tm.cleanupTask(sessionId) return nil } // notifyAgentToTerminate 通知 Agent 终止任务(简化版本) func (tm *TaskManager) notifyAgentToTerminate(agentID string, sessionId string, traceID string) { // 异步通知Agent,避免阻塞 go func() { // 获取 Agent 连接 availableAgents := tm.agentManager.GetAvailableAgents() for _, agent := range availableAgents { agent.stateMu.RLock() currentAgentID := agent.agentID isActive := agent.isActive agent.stateMu.RUnlock() if currentAgentID == agentID && isActive { // 发送终止消息给 Agent terminateMsg := WSMessage{ Type: "terminate", Content: map[string]interface{}{ "session_id": sessionId, "reason": "用户主动终止", }, } // 直接发送,无重试机制 agent.conn.SetWriteDeadline(time.Now().Add(writeWait)) err := agent.conn.WriteJSON(terminateMsg) if err != nil { log.Errorf("发送终止消息给Agent %s失败: %v", agentID, err) } else { log.Infof("终止消息已发送给Agent %s: trace_id=%s, sessionId=%s", agentID, traceID, sessionId) } break } } }() } // sendTerminationEvent 发送终止事件给前端 func (tm *TaskManager) sendTerminationEvent(sessionId string, traceID string) { event := StatusUpdateEvent{ ID: generateEventID(), Type: "statusUpdate", Timestamp: time.Now().UnixMilli(), AgentStatus: "terminated", Brief: "任务已终止", Description: "用户主动终止了任务执行", NoRender: false, } // 使用通用事件处理函数 tm.handleEvent(sessionId, "statusUpdate", event) log.Infof("终止事件已发送: trace_id=%s, sessionId=%s", traceID, sessionId) } // generateEventID 生成事件ID func generateEventID() string { return time.Now().Format("20060102150405") + "_" + fmt.Sprintf("%d", time.Now().UnixNano()) } // UpdateTask 更新任务信息 func (tm *TaskManager) UpdateTask(sessionId string, req *TaskUpdateRequest, username string, traceID string) error { log.Infof("开始更新任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) // 1. 验证任务是否存在 session, err := tm.taskStore.GetSession(sessionId) if err != nil { log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) return fmt.Errorf("任务不存在") } // 2. 验证权限(只有任务创建者才能更新) if session.Username != username { log.Errorf("无权限操作此任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) return fmt.Errorf("无权限操作此任务") } // 3. 更新任务信息 updates := map[string]interface{}{ "title": req.Title, } err = tm.taskStore.UpdateSession(sessionId, updates) if err != nil { log.Errorf("更新任务信息失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) return fmt.Errorf("更新任务信息失败: %v", err) } log.Infof("任务信息更新成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) return nil } // DeleteTask 删除任务 func (tm *TaskManager) DeleteTask(sessionId string, username string, traceID string) error { log.Infof("开始删除任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) // 检查任务是否存在 session, err := tm.taskStore.GetSession(sessionId) if err != nil { log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) return fmt.Errorf("任务不存在") } // 验证用户权限(只有任务创建者才能删除任务) if session.Username != username { log.Errorf("无权限操作此任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) return fmt.Errorf("无权限操作此任务") } // 使用事务删除会话及其所有消息 err = tm.taskStore.DeleteSessionWithMessages(sessionId) if err != nil { log.Errorf("删除任务失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) return fmt.Errorf("删除任务失败: %v", err) } // 删除附件文件 err = tm.deleteSessionAttachments(session) if err != nil { log.Errorf("删除附件文件失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) // 附件删除失败不影响主流程,只记录警告 } // 清理内存中的任务数据 tm.mu.Lock() delete(tm.tasks, sessionId) tm.mu.Unlock() // 关闭SSE连接 tm.CloseSSESession(sessionId) log.Infof("任务删除完成: trace_id=%s, sessionId=%s", traceID, sessionId) return nil } // deleteSessionAttachments 删除会话的附件文件 func (tm *TaskManager) deleteSessionAttachments(session *database.Session) error { if session.Attachments == nil { return nil } var attachmentURLs []string if err := json.Unmarshal(session.Attachments, &attachmentURLs); err != nil { return fmt.Errorf("解析附件URL失败: %v", err) } for _, url := range attachmentURLs { // 从URL中提取文件名 fileName := tm.extractFileNameFromURL(url) if fileName == url { // 如果无法提取文件名,跳过 continue } // 构建完整的文件路径 filePath := filepath.Join(tm.fileConfig.UploadDir, fileName) // 删除文件 if err := os.Remove(filePath); err != nil { if !os.IsNotExist(err) { log.Errorf("删除附件文件失败: %s, error: %v", filePath, err) } } else { log.Debugf("删除附件文件成功: %s", filePath) } } return nil } // UploadFileResult 文件上传结果 type UploadFileResult struct { Filename string `json:"filename"` // 原始文件名 FileURL string `json:"fileUrl"` // 文件访问URL } // UploadFile 上传文件 func (tm *TaskManager) UploadFile(file *multipart.FileHeader, traceID string) (*UploadFileResult, error) { log.Infof("开始文件上传: trace_id=%s, originalName=%s, size=%d", traceID, file.Filename, file.Size) // 保存原始文件名 originalName := file.Filename // 生成安全的唯一文件名 fileName := generateSecureFileName(file.Filename) log.Debugf("生成安全文件名: trace_id=%s, originalName=%s, secureName=%s", traceID, originalName, fileName) // 使用配置的上传目录 uploadDir := tm.fileConfig.UploadDir if err := os.MkdirAll(uploadDir, 0755); err != nil { log.Errorf("创建上传目录失败: trace_id=%s, path=%s, error=%v", traceID, uploadDir, err) return nil, fmt.Errorf("创建上传目录失败: %v", err) } // 创建文件路径 filePath := filepath.Join(uploadDir, fileName) // 保存文件到本地 src, err := file.Open() if err != nil { log.Errorf("打开上传文件失败: trace_id=%s, originalName=%s, error=%v", traceID, originalName, err) return nil, fmt.Errorf("打开文件失败: %v", err) } defer src.Close() dst, err := os.Create(filePath) if err != nil { log.Errorf("创建目标文件失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) return nil, fmt.Errorf("创建文件失败: %v", err) } defer dst.Close() // 复制文件内容并验证 written, err := io.Copy(dst, src) if err != nil { // 清理已创建的文件 os.Remove(filePath) log.Errorf("文件写入失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) return nil, fmt.Errorf("保存文件失败: %v", err) } // 验证写入的文件大小 if written != file.Size { os.Remove(filePath) log.Errorf("文件写入不完整: trace_id=%s, expected=%d, actual=%d, filePath=%s", traceID, file.Size, written, filePath) return nil, fmt.Errorf("文件写入不完整") } // 生成文件访问URL fileURL := tm.fileConfig.GetFileURL(fileName) log.Infof("文件上传成功: trace_id=%s, originalName=%s, secureName=%s, size=%d, fileURL=%s", traceID, originalName, fileName, written, fileURL) return &UploadFileResult{ Filename: originalName, FileURL: fileURL, }, nil } // GetUserTasks 获取指定用户的任务列表,只返回属于该用户的会话,确保用户只能看到自己的任务。 func (tm *TaskManager) GetUserTasks(username string, traceID string) ([]map[string]interface{}, error) { // 从数据库获取用户的任务列表 sessions, err := tm.taskStore.GetUserSessions(username) if err != nil { log.Errorf("获取用户任务列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err) return nil, fmt.Errorf("获取任务列表失败: %v", err) } // 转换为前端需要的格式 var tasks []map[string]interface{} for _, session := range sessions { task := map[string]interface{}{ "sessionId": session.ID, "title": session.Title, "taskType": session.TaskType, "status": session.Status, "countryIsoCode": session.CountryIsoCode, "updatedAt": session.UpdatedAt, // 直接使用时间戳毫秒级 "createdAt": session.CreatedAt, // 任务创建时间 } // 添加完成时间(如果任务已完成) if session.CompletedAt != nil { task["completedAt"] = *session.CompletedAt } else { task["completedAt"] = nil } tasks = append(tasks, task) } return tasks, nil } // GetUserTasksByType 获取指定用户的任务列表,支持可选的任务类型过滤 func (tm *TaskManager) GetUserTasksByType(username string, taskType string, traceID string) ([]map[string]interface{}, error) { // 从数据库获取用户的任务列表(支持类型过滤) sessions, err := tm.taskStore.GetUserSessionsByType(username, taskType) if err != nil { log.Errorf("获取用户任务列表失败: trace_id=%s, username=%s, taskType=%s, error=%v", traceID, username, taskType, err) return nil, fmt.Errorf("获取任务列表失败: %v", err) } // 转换为前端需要的格式 var tasks []map[string]interface{} for _, session := range sessions { task := map[string]interface{}{ "sessionId": session.ID, "title": session.Title, "taskType": session.TaskType, "status": session.Status, "countryIsoCode": session.CountryIsoCode, "updatedAt": session.UpdatedAt, // 直接使用时间戳毫秒级 "createdAt": session.CreatedAt, // 任务创建时间 } // 添加完成时间(如果任务已完成) if session.CompletedAt != nil { task["completedAt"] = *session.CompletedAt } else { task["completedAt"] = nil } tasks = append(tasks, task) } return tasks, nil } // SearchUserTasksSimple 使用简化参数搜索指定用户的任务,支持单个查询关键词和分页 func (tm *TaskManager) SearchUserTasksSimple(username string, searchParams database.SimpleSearchParams, traceID string) ([]map[string]interface{}, error) { log.Infof("开始简化搜索用户任务: trace_id=%s, username=%s, query=%s, taskType=%s", traceID, username, searchParams.Query, searchParams.TaskType) // 验证和设置默认分页参数 if searchParams.Page < 1 { searchParams.Page = 1 } if searchParams.PageSize < 1 { searchParams.PageSize = 10 } if searchParams.PageSize > 100 { searchParams.PageSize = 100 // 限制最大页面大小 } // 从数据库搜索用户的任务列表 sessions, _, err := tm.taskStore.SearchUserSessionsSimple(username, searchParams) if err != nil { log.Errorf("简化搜索用户任务失败: trace_id=%s, username=%s, taskType=%s, error=%v", traceID, username, searchParams.TaskType, err) return nil, fmt.Errorf("搜索任务失败: %v", err) } // 转换为前端需要的格式 var tasks []map[string]interface{} for _, session := range sessions { task := map[string]interface{}{ "sessionId": session.ID, "title": session.Title, "taskType": session.TaskType, "status": session.Status, "countryIsoCode": session.CountryIsoCode, "updatedAt": session.UpdatedAt, // 直接使用时间戳毫秒级 "createdAt": session.CreatedAt, // 任务创建时间 } // 添加完成时间(如果任务已完成) if session.CompletedAt != nil { task["completedAt"] = *session.CompletedAt } else { task["completedAt"] = nil } tasks = append(tasks, task) } return tasks, nil } // generateTaskTitle 生成任务标题(用于任务创建API) func (tm *TaskManager) generateTaskTitle(req *TaskCreateRequest) string { ret := "" var ModelName = "" language := req.CountryIsoCode if language == "" { language = "zh" } // 定义语言相关的文本 var texts struct { // 任务类型标题 aiInfraScan, mcpScan, modelJailbreak, modelRedteamReport, otherTask string // 其他文本 model, prompt, github, sse string } if language == "en" { texts.aiInfraScan = "AI Infra Scan - " texts.mcpScan = "MCP Scan - " texts.modelJailbreak = "LLM Jailbreaking - " texts.modelRedteamReport = "Jailbreak Evaluation - " texts.otherTask = "Other Task - " texts.model = "Model:" texts.prompt = "Prompt:" texts.github = "Github:" texts.sse = "SSE:" } else { texts.aiInfraScan = "AI基础设施扫描 - " texts.mcpScan = "MCP扫描 - " texts.modelJailbreak = "一键越狱任务 - " texts.modelRedteamReport = "大模型安全体检 - " texts.otherTask = "其他任务 - " texts.model = "模型:" texts.prompt = "prompt:" texts.github = "Github:" texts.sse = "SSE:" } if modelID, exists := req.Params["model_id"]; exists { switch v := modelID.(type) { case string: model, err := tm.modelStore.GetModel(v) if err == nil { ModelName = model.ModelName } case []interface{}: modelStr := make([]string, 0) for _, mid := range v { mid, ok := mid.(string) if !ok { continue } model, err := tm.modelStore.GetModel(mid) if err == nil { modelStr = append(modelStr, model.ModelName) } } ModelName = strings.Join(modelStr, ",") } } // 1. AI基础 ip/域名 ,文件形式:取第一行等xx个 // 2. MCP:文件名以文件展示,github取项目名,sse取链接 // 3. 评测:模型名 eg:qwen3模型评测任务 // 4. 一键越狱:模型名+prompt switch req.Task { case agent.TaskTypeAIInfraScan: ret = texts.aiInfraScan if len(req.Attachments) > 0 && req.Attachments[0] != "" { ret += tm.extractFileNameFromURL(req.Attachments[0]) } if req.Content != "" { ret += req.Content } case agent.TaskTypeMcpScan: ret = texts.mcpScan if len(req.Attachments) > 0 && req.Attachments[0] != "" { // 直接调用现有的extractFileNameFromURL方法 ret += tm.extractFileNameFromURL(req.Attachments[0]) } else if strings.Contains(req.Content, "github.com") { ret += texts.github + tm.extractFileNameFromURL(req.Content) } else { ret += texts.sse + req.Content } case agent.TaskTypeModelJailbreak: ret = texts.modelJailbreak + fmt.Sprintf("%s%s, %s%s", texts.model, ModelName, texts.prompt, req.Content) case agent.TaskTypeModelRedteamReport: ret = texts.modelRedteamReport + ModelName default: ret = texts.otherTask + req.Content } // 如果content为空,尝试从附件中提取第一个URL的文件名作为title return ret } // 辅助函数:将interface{}转换为datatypes.JSON func mustMarshalJSON(v interface{}) datatypes.JSON { if v == nil { return datatypes.JSON("{}") } data, err := json.Marshal(v) if err != nil { return datatypes.JSON("{}") } return datatypes.JSON(data) } // EstablishSSEConnection 建立SSE连接 func (tm *TaskManager) EstablishSSEConnection(w http.ResponseWriter, sessionId string, username string, traceID string) error { log.Infof("建立SSE连接: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) err := tm.sseManager.AddConnection(sessionId, username, w) if err != nil { log.Errorf("建立SSE连接失败: trace_id=%s, sessionId=%s, username=%s, error=%v", traceID, sessionId, username, err) } else { log.Infof("SSE连接建立成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) } return err } // CloseSSESession 关闭SSE会话 func (tm *TaskManager) CloseSSESession(sessionId string) { log.Infof("关闭SSE会话: sessionId=%s", sessionId) tm.sseManager.RemoveConnection(sessionId) log.Infof("SSE会话已关闭: sessionId=%s", sessionId) } // 任务完成/中断时的清理 func (tm *TaskManager) cleanupTask(sessionId string) { log.Infof("开始清理任务资源: sessionId=%s", sessionId) // 清理内存中的任务数据 tm.mu.Lock() delete(tm.tasks, sessionId) tm.mu.Unlock() // 注意:SSE连接已在resultUpdate事件处理中立即清理 tm.CloseSSESession(sessionId) log.Infof("任务清理完成: sessionId=%s", sessionId) } // GetTaskDetail 获取任务详情 func (tm *TaskManager) GetTaskDetail(sessionId string, username string, traceID string) (map[string]interface{}, error) { log.Infof("开始获取任务详情: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) // 检查任务是否存在 session, err := tm.taskStore.GetSession(sessionId) if err != nil { log.Errorf("获取任务详情失败: trace_id=%s, sessionId=%s, username=%s, error=%v", traceID, sessionId, username, err) return nil, fmt.Errorf("任务不存在") } // 验证用户权限(只有任务创建者才能查看) if !session.Share && session.Username != username { log.Errorf("无权限访问任务详情: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) return nil, fmt.Errorf("无权限查看此任务") } // 获取任务的所有消息 messages, err := tm.taskStore.GetSessionMessages(sessionId) if err != nil { log.Errorf("获取任务消息失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) return nil, fmt.Errorf("获取任务消息失败: %v", err) } // 处理附件信息 var attachments []map[string]interface{} if session.Attachments != nil { var attachmentURLs []string if err := json.Unmarshal(session.Attachments, &attachmentURLs); err == nil { for _, url := range attachmentURLs { // 从URL中提取文件名 fileName := tm.extractFileNameFromURL(url) attachments = append(attachments, map[string]interface{}{ "filename": fileName, "fileUrl": url, }) } } } // 处理消息列表 var messageList []map[string]interface{} for _, msg := range messages { // 解析事件数据 var eventData map[string]interface{} if err := json.Unmarshal(msg.EventData, &eventData); err != nil { continue } messageList = append(messageList, map[string]interface{}{ "id": msg.ID, "type": msg.Type, "timestamp": msg.Timestamp, "event": eventData, }) } // 处理任务参数 var params map[string]interface{} if session.Params != nil { if err := json.Unmarshal(session.Params, ¶ms); err != nil { log.Warnf("解析任务参数失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) params = make(map[string]interface{}) } } else { params = make(map[string]interface{}) } // 构建返回数据 detail := map[string]interface{}{ "sessionId": session.ID, "title": session.Title, "status": session.Status, "countryIsoCode": session.CountryIsoCode, "createdAt": session.CreatedAt, "content": session.Content, "params": params, "taskType": session.TaskType, "attachments": attachments, "messages": messageList, } if session.Username != username { delete(detail, "attachments") } log.Infof("获取任务详情成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) return detail, nil } // extractFileNameFromURL 从文件URL中提取原始文件名 func (tm *TaskManager) extractFileNameFromURL(url string) string { // 新的文件名格式: UUID_原始文件名.扩展名 if strings.Contains(url, "/") { parts := strings.Split(url, "/") if len(parts) > 0 { fileName := parts[len(parts)-1] // 新的文件名格式: UUID_原始文件名.扩展名 if strings.Contains(fileName, "_") { // 查找第一个下划线,之后的部分是原始文件名 firstUnderscoreIndex := strings.Index(fileName, "_") if firstUnderscoreIndex > 0 { // 返回下划线后的部分作为原始文件名 return fileName[firstUnderscoreIndex+1:] } } // 如果没有下划线,直接返回文件名 return fileName } } return url } // DownloadFile 下载文件 func (tm *TaskManager) DownloadFile(sessionId string, fileUrl string, username string, c *gin.Context, traceID string) error { log.Infof("开始文件下载: trace_id=%s, sessionId=%s, fileUrl=%s, username=%s", traceID, sessionId, fileUrl, username) filename := strings.TrimLeft(fileUrl, "/") filePath, _ := filepath.Abs(filepath.Join(tm.fileConfig.UploadDir, filename)) if !strings.HasPrefix(filePath, tm.fileConfig.UploadDir) { return fmt.Errorf("文件路径不合法") } if _, err := os.Stat(filePath); os.IsNotExist(err) { log.Errorf("本地文件不存在: trace_id=%s, filePath=%s", traceID, filePath) return fmt.Errorf("文件不存在") } fileInfo, err := os.Stat(filePath) if err != nil { log.Errorf("获取文件信息失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) return fmt.Errorf("获取文件信息失败: %v", err) } log.Debugf("文件信息获取成功: trace_id=%s, filePath=%s, size=%d", traceID, filePath, fileInfo.Size()) // 8. 设置响应头 // 获取文件的MIME类型 ext := filepath.Ext(filePath) mimeType := mime.TypeByExtension(ext) if mimeType == "" { mimeType = "application/octet-stream" } // 设置Content-Type c.Header("Content-Type", mimeType) // 设置Content-Disposition,支持中文文件名 // 使用UTF-8编码处理中文文件名 encodedFileName := url.QueryEscape(filepath.Base(filePath)) c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"; filename*=UTF-8''%s", encodedFileName, encodedFileName)) // 设置Content-Length c.Header("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) // 9. 打开文件并流式传输 file, err := os.Open(filePath) if err != nil { log.Errorf("打开文件失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) return fmt.Errorf("打开文件失败: %v", err) } defer file.Close() // 10. 流式传输文件内容 written, err := io.Copy(c.Writer, file) if err != nil { log.Errorf("文件传输失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) return fmt.Errorf("传输文件失败: %v", err) } log.Infof("文件下载成功: trace_id=%s, sessionId=%s, fileName=%s, fileSize=%d, transmittedSize=%d", traceID, sessionId, filePath, fileInfo.Size(), written) return nil }