| package websocket |
|
|
| import ( |
| "encoding/json" |
| "fmt" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/go-playground/validator/v10" |
| "github.com/gorilla/websocket" |
| "trpc.group/trpc-go/trpc-go/log" |
| |
| ) |
|
|
// Connection tuning for agent WebSockets.
const (
	// maxMessageSize is the per-message read limit passed to SetReadLimit.
	// NOTE(review): 512 MiB lets a single agent force very large allocations
	// in ReadMessage; consider a tighter bound if the agent protocol allows.
	maxMessageSize = 512 * 1024 * 1024
	// pongWait is how far each received pong pushes the read deadline.
	pongWait = 120 * time.Second
	// pingPeriod is the heartbeat interval; it is 80% of pongWait so a ping
	// is always in flight before the read deadline can expire.
	pingPeriod = pongWait * 8 / 10
	// writeWait bounds how long a single write may take.
	writeWait = 60 * time.Second
)

// WebSocket message types exchanged with agents.
const (
	// Connection lifecycle messages.
	WSMsgTypeRegister   = "register"
	WSMsgTypeDisconnect = "disconnect"

	// Task event messages; handleConnection routes these to handleAgentEvent.
	WSMsgTypeLiveStatus   = "liveStatus"
	WSMsgTypePlanUpdate   = "planUpdate"
	WSMsgTypeNewPlanStep  = "newPlanStep"
	WSMsgTypeStatusUpdate = "statusUpdate"
	WSMsgTypeToolUsed     = "toolUsed"
	WSMsgTypeResultUpdate = "resultUpdate"
	WSMsgTypeActionLog    = "actionLog"
	WSMsgTypeError        = "error"
)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| type AgentConnection struct { |
| conn *websocket.Conn |
| agentID string |
|
|
| |
| stateMu sync.RWMutex |
| writeMu sync.Mutex |
|
|
| isActive bool |
| } |
|
|
| |
| type AgentManager struct { |
| connections map[string]*AgentConnection |
| mu sync.RWMutex |
| taskManager *TaskManager |
| |
| } |
|
|
| |
// AgentRegisterContent is the payload of a "register" message.
//
// handleRegister decodes it from WSMessage.Content and checks the validate
// tags. agent_id, hostname, ip and version are required, and ip must be a
// valid IP address. Capabilities and Meta are optional. handleRegister decodes
// them but does not otherwise use them.
type AgentRegisterContent struct {
	AgentID      string   `json:"agent_id" validate:"required"`
	Hostname     string   `json:"hostname" validate:"required"`
	IP           string   `json:"ip" validate:"required,ip"`
	Version      string   `json:"version" validate:"required"`
	Capabilities []string `json:"capabilities,omitempty"`
	Meta         string   `json:"meta,omitempty"`
}
|
|
| |
// DisconnectContent is the payload of a "disconnect" message.
//
// agent_id is required, and handleDisconnect only honours it when it matches
// the ID this connection registered with. reason is optional free text.
type DisconnectContent struct {
	AgentID string `json:"agent_id" validate:"required"`
	Reason  string `json:"reason,omitempty"`
}
|
|
| |
| var validate *validator.Validate |
|
|
| |
| func init() { |
| validate = validator.New() |
| } |
|
|
| |
| func formatValidationErrors(err error) string { |
| if validationErrors, ok := err.(validator.ValidationErrors); ok { |
| var errorMessages []string |
| for _, fieldError := range validationErrors { |
| fieldName := fieldError.Field() |
| switch fieldError.Tag() { |
| case "required": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("缺少必需字段: %s", fieldName)) |
| case "ip": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 必须是有效的IP地址", fieldName)) |
| case "email": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 必须是有效的邮箱格式", fieldName)) |
| case "url": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 必须是有效的URL", fieldName)) |
| case "min": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 长度不能小于 %s", fieldName, fieldError.Param())) |
| case "max": |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 长度不能大于 %s", fieldName, fieldError.Param())) |
| default: |
| errorMessages = append(errorMessages, |
| fmt.Sprintf("字段 %s 验证失败: %s", fieldName, fieldError.Tag())) |
| } |
| } |
| return fmt.Sprintf("验证失败: %s", strings.Join(errorMessages, "; ")) |
| } |
| return "验证失败" |
| } |
|
|
| |
| |
| func NewAgentManager() *AgentManager { |
| return &AgentManager{ |
| connections: make(map[string]*AgentConnection), |
| |
| } |
| } |
|
|
| |
| |
| func NewAgentConnection(conn *websocket.Conn) *AgentConnection { |
| return &AgentConnection{ |
| conn: conn, |
| |
| isActive: true, |
| } |
| } |
|
|
| |
| func (am *AgentManager) HandleAgentWebSocket() gin.HandlerFunc { |
|
|
| return func(c *gin.Context) { |
| conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) |
| if err != nil { |
| log.Errorf("WebSocket升级失败: error=%v", err) |
| return |
| } |
|
|
| |
| ac := NewAgentConnection(conn) |
| log.Infof("新的Agent连接建立: remoteAddr=%s", conn.RemoteAddr().String()) |
| go ac.handleConnection(am) |
| } |
| } |
|
|
| |
| func (ac *AgentConnection) handleConnection(am *AgentManager) { |
| defer func() { |
| ac.stateMu.RLock() |
| agentID := ac.agentID |
| remoteAddr := ac.conn.RemoteAddr().String() |
| ac.stateMu.RUnlock() |
| ac.cleanup(am) |
| log.Infof("Agent连接处理结束: agentId=%s, remoteAddr=%s", agentID, remoteAddr) |
| }() |
|
|
| |
| ac.conn.SetReadLimit(maxMessageSize) |
| ac.conn.SetPongHandler(func(string) error { |
| ac.conn.SetReadDeadline(time.Now().Add(pongWait)) |
| return nil |
| }) |
|
|
| |
| go ac.writePump() |
|
|
| |
| for { |
| _, message, err := ac.conn.ReadMessage() |
| if err != nil { |
| ac.stateMu.RLock() |
| agentID := ac.agentID |
| ac.stateMu.RUnlock() |
|
|
| if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { |
| log.Errorf("Agent连接异常断开: agentId=%s, error=%v", agentID, err) |
| } else { |
| log.Infof("Agent连接正常断开: agentId=%s, closeCode=%v", agentID, err) |
| } |
| break |
| } |
|
|
| var wsMsg WSMessage |
| if err := json.Unmarshal(message, &wsMsg); err != nil { |
| log.Errorf("Agent消息解析失败: agentId=%s, error=%v", ac.agentID, err) |
| |
| ac.sendError("消息格式错误请检查JSON格式") |
| continue |
| } |
|
|
| |
| if wsMsg.Type == "" { |
| log.Errorf("Agent消息类型为空: agentId=%s", ac.agentID) |
| ac.sendError("消息类型不能为空") |
| continue |
| } |
|
|
| switch wsMsg.Type { |
| case WSMsgTypeRegister: |
| ac.handleRegister(am, wsMsg.Content) |
| case WSMsgTypeDisconnect: |
| |
| ac.handleDisconnect(am, wsMsg.Content) |
| |
| ac.stateMu.RLock() |
| if !ac.isActive { |
| ac.stateMu.RUnlock() |
| return |
| } |
| ac.stateMu.RUnlock() |
| case WSMsgTypeLiveStatus, WSMsgTypePlanUpdate, WSMsgTypeNewPlanStep, WSMsgTypeStatusUpdate, WSMsgTypeToolUsed, WSMsgTypeResultUpdate, WSMsgTypeActionLog, WSMsgTypeError: |
| |
| ac.handleAgentEvent(am, wsMsg.Content, wsMsg.Type) |
| default: |
| log.Warnf("Agent发送未知消息类型: agentId=%s, type=%s", ac.agentID, wsMsg.Type) |
| ac.sendError(fmt.Sprintf("未知的消息类型: %s。支持的类型: register, disconnect, liveStatus, planUpdate, newPlanStep, statusUpdate, toolUsed, resultUpdate, actionLog", wsMsg.Type)) |
| } |
| } |
| } |
|
|
| |
| func (ac *AgentConnection) handleRegister(am *AgentManager, content interface{}) { |
| contentBytes, _ := json.Marshal(content) |
| var rc AgentRegisterContent |
| if err := json.Unmarshal(contentBytes, &rc); err != nil { |
| log.Errorf("Agent注册消息解析失败: error=%v", err) |
| ac.sendError("注册消息格式错误") |
| return |
| } |
|
|
| |
| if err := validate.Struct(rc); err != nil { |
| errorMsg := formatValidationErrors(err) |
| log.Errorf("Agent注册验证失败: agentId=%s, error=%s", rc.AgentID, errorMsg) |
| ac.sendError(errorMsg) |
| return |
| } |
|
|
| |
| am.mu.Lock() |
| if existingConn, exists := am.connections[rc.AgentID]; exists { |
| am.mu.Unlock() |
| log.Warnf("Agent ID已存在,断开旧连接: agentId=%s", rc.AgentID) |
| |
| existingConn.stateMu.Lock() |
| existingConn.isActive = false |
| existingConn.stateMu.Unlock() |
| existingConn.conn.Close() |
| } else { |
| am.mu.Unlock() |
| } |
|
|
| |
| am.mu.Lock() |
| am.connections[rc.AgentID] = ac |
| am.mu.Unlock() |
|
|
| |
| ac.stateMu.Lock() |
| ac.agentID = rc.AgentID |
| ac.isActive = true |
| ac.stateMu.Unlock() |
|
|
| log.Infof("Agent注册成功: agentId=%s, hostname=%s, ip=%s, version=%s", rc.AgentID, rc.Hostname, rc.IP, rc.Version) |
| |
| response := WSMessage{ |
| Type: "register_ack", |
| Content: Response{ |
| Status: 0, |
| Message: "注册成功", |
| }, |
| } |
| ac.conn.WriteJSON(response) |
| } |
|
|
| |
| func (ac *AgentConnection) handleDisconnect(am *AgentManager, content interface{}) { |
| contentBytes, _ := json.Marshal(content) |
| var dc DisconnectContent |
| if err := json.Unmarshal(contentBytes, &dc); err != nil { |
| ac.sendError("断开连接消息格式错误") |
| return |
| } |
|
|
| |
| if err := validate.Struct(dc); err != nil { |
| errorMsg := formatValidationErrors(err) |
| ac.sendError(errorMsg) |
| return |
| } |
|
|
| |
| ac.stateMu.RLock() |
| agentID := ac.agentID |
| ac.stateMu.RUnlock() |
|
|
| if agentID == "" || agentID != dc.AgentID { |
| ac.sendError("断开连接消息身份验证失败") |
| return |
| } |
|
|
| |
| am.mu.Lock() |
| delete(am.connections, agentID) |
| am.mu.Unlock() |
|
|
| |
| response := WSMessage{ |
| Type: "disconnect_ack", |
| Content: Response{ |
| Status: 0, |
| Message: "断开连接成功", |
| }, |
| } |
| ac.conn.WriteJSON(response) |
|
|
| |
| ac.stateMu.Lock() |
| ac.isActive = false |
| ac.stateMu.Unlock() |
| } |
|
|
| |
| func (ac *AgentConnection) writePump() { |
| ticker := time.NewTicker(pingPeriod) |
| defer func() { |
| ticker.Stop() |
| log.Infof("Agent心跳检测已停止: agentId=%s", ac.agentID) |
| }() |
|
|
| log.Infof("Agent心跳检测已启动: agentId=%s, pingPeriod=%v", ac.agentID, pingPeriod) |
|
|
| for range ticker.C { |
| ac.stateMu.RLock() |
| if !ac.isActive { |
| ac.stateMu.RUnlock() |
| log.Infof("Agent连接已标记为非活跃,停止心跳检测: agentId=%s", ac.agentID) |
| return |
| } |
| agentID := ac.agentID |
| ac.stateMu.RUnlock() |
|
|
| |
| ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
|
|
| |
| err := ac.conn.WriteMessage(websocket.PingMessage, nil) |
| if err != nil { |
| log.Warnf("Agent心跳发送失败,准备重试: agentId=%s, error=%v", agentID, err) |
|
|
| |
| time.Sleep(1 * time.Second) |
| ac.stateMu.RLock() |
| if !ac.isActive { |
| ac.stateMu.RUnlock() |
| log.Infof("Agent连接在重试期间已标记为非活跃: agentId=%s", agentID) |
| return |
| } |
| ac.stateMu.RUnlock() |
|
|
| ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
| err = ac.conn.WriteMessage(websocket.PingMessage, nil) |
| if err != nil { |
| log.Errorf("Agent心跳重试失败,连接已失效: agentId=%s, error=%v", agentID, err) |
|
|
| |
| ac.stateMu.Lock() |
| ac.isActive = false |
| ac.stateMu.Unlock() |
|
|
| log.Errorf("Agent连接已标记为失效: agentId=%s, 原因=心跳失败", agentID) |
| return |
| } else { |
| log.Infof("Agent心跳重试成功: agentId=%s", agentID) |
| } |
| } else { |
| log.Debugf("Agent心跳发送成功: agentId=%s", agentID) |
| } |
| } |
| } |
|
|
| |
| func (ac *AgentConnection) cleanup(am *AgentManager) { |
| ac.stateMu.Lock() |
| agentID := ac.agentID |
| wasActive := ac.isActive |
| ac.isActive = false |
| ac.stateMu.Unlock() |
|
|
| log.Infof("开始清理Agent连接: agentId=%s, wasActive=%v", agentID, wasActive) |
|
|
| if agentID != "" { |
| am.mu.Lock() |
| |
| if _, exists := am.connections[agentID]; exists { |
| delete(am.connections, agentID) |
| log.Infof("Agent已从连接管理器中移除: agentId=%s", agentID) |
| } else { |
| log.Warnf("Agent不在连接管理器中,可能已被移除: agentId=%s", agentID) |
| } |
| am.mu.Unlock() |
|
|
| |
| } else { |
| log.Warnf("清理未注册的Agent连接: remoteAddr=%s", ac.conn.RemoteAddr().String()) |
| } |
|
|
| |
| err := ac.conn.Close() |
| if err != nil { |
| log.Warnf("关闭Agent连接时出错: agentId=%s, error=%v", agentID, err) |
| } else { |
| log.Infof("Agent连接已关闭: agentId=%s", agentID) |
| } |
|
|
| log.Infof("Agent连接清理完成: agentId=%s", agentID) |
| } |
|
|
| |
| func (ac *AgentConnection) sendError(message string) { |
| response := WSMessage{ |
| Type: "error", |
| Content: Response{ |
| Status: 1, |
| Message: message, |
| }, |
| } |
|
|
| |
| ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
|
|
| err := ac.conn.WriteJSON(response) |
| if err != nil { |
| |
| ac.stateMu.Lock() |
| ac.isActive = false |
| ac.stateMu.Unlock() |
| } |
| } |
|
|
| |
| func (ac *AgentConnection) handleAgentEvent(am *AgentManager, content interface{}, eventType string) { |
| contentBytes, err := json.Marshal(content) |
| if err != nil { |
| log.Errorf("Agent事件序列化失败: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err) |
| ac.sendError(fmt.Sprintf("%s事件序列化失败: %v", eventType, err)) |
| return |
| } |
|
|
| var eventMessage TaskEventMessage |
| if err := json.Unmarshal(contentBytes, &eventMessage); err != nil { |
| log.Errorf("Agent事件格式错误: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err) |
| ac.sendError(fmt.Sprintf("%s事件格式错误: %v", eventType, err)) |
| return |
| } |
|
|
| |
| if err := validate.Struct(eventMessage); err != nil { |
| errorMsg := formatValidationErrors(err) |
| log.Errorf("Agent事件验证失败: agentId=%s, eventType=%s, error=%s", ac.agentID, eventType, errorMsg) |
| ac.sendError(fmt.Sprintf("%s事件验证失败: %s", eventType, errorMsg)) |
| return |
| } |
|
|
| |
| sessionId := eventMessage.SessionID |
| event := eventMessage.Event |
|
|
| log.Debugf("收到Agent事件: agentId=%s, sessionId=%s, eventType=%s", ac.agentID, sessionId, eventType) |
|
|
| |
| am.mu.RLock() |
| am.taskManager.HandleAgentEvent(sessionId, eventType, event) |
| am.mu.RUnlock() |
| } |
|
|
| |
| func (am *AgentManager) GetAvailableAgents() []*AgentConnection { |
| am.mu.RLock() |
| defer am.mu.RUnlock() |
|
|
| var availableAgents []*AgentConnection |
| for _, conn := range am.connections { |
| conn.stateMu.RLock() |
| if conn.isActive { |
| availableAgents = append(availableAgents, conn) |
| } |
| conn.stateMu.RUnlock() |
| } |
| return availableAgents |
| } |
|
|
| |
| func (am *AgentManager) SetTaskManager(taskManager *TaskManager) { |
| am.mu.Lock() |
| defer am.mu.Unlock() |
| am.taskManager = taskManager |
| } |
|
|