| package websocket |
|
|
| import ( |
| "encoding/json" |
| "fmt" |
| "net/http" |
| "sync" |
| "time" |
|
|
| "trpc.group/trpc-go/trpc-go/log" |
| ) |
|
|
| |
| type SSEConnection struct { |
| SessionID string |
| Username string |
| Writer http.ResponseWriter |
| Flusher http.Flusher |
| CloseChan chan bool |
| LastPing time.Time |
| } |
|
|
| |
| type SSEManager struct { |
| connections map[string]*SSEConnection |
| mutex sync.RWMutex |
| } |
|
|
| |
| func NewSSEManager() *SSEManager { |
| return &SSEManager{ |
| connections: make(map[string]*SSEConnection), |
| } |
| } |
|
|
| |
| func (sm *SSEManager) AddConnection(sessionID, username string, w http.ResponseWriter) error { |
| sm.mutex.Lock() |
| defer sm.mutex.Unlock() |
|
|
| |
| if existing, exists := sm.connections[sessionID]; exists { |
| |
| close(existing.CloseChan) |
| log.Infof("SSE连接冲突,关闭现有连接: sessionId=%s, username=%s", sessionID, username) |
| } |
|
|
| |
| flusher, ok := w.(http.Flusher) |
| if !ok { |
| log.Errorf("SSE流式传输不支持: sessionId=%s, username=%s", sessionID, username) |
| return fmt.Errorf("streaming unsupported") |
| } |
|
|
| |
| w.Header().Set("Content-Type", "text/event-stream") |
| w.Header().Set("Cache-Control", "no-cache") |
| w.Header().Set("Connection", "keep-alive") |
| w.Header().Set("Access-Control-Allow-Origin", "*") |
| w.Header().Set("Access-Control-Allow-Headers", "Cache-Control") |
|
|
| |
| conn := &SSEConnection{ |
| SessionID: sessionID, |
| Username: username, |
| Writer: w, |
| Flusher: flusher, |
| CloseChan: make(chan bool), |
| LastPing: time.Now(), |
| } |
|
|
| sm.connections[sessionID] = conn |
| log.Infof("SSE连接建立: sessionId=%s, username=%s, totalConnections=%d", sessionID, username, len(sm.connections)) |
|
|
| |
| sm.sendEventToConnection(conn, "connected", "connected", map[string]interface{}{ |
| "message": "SSE连接已建立", |
| "sessionId": sessionID, |
| }) |
|
|
| |
| go sm.keepConnectionAlive(conn) |
|
|
| return nil |
| } |
|
|
| |
| func (sm *SSEManager) keepConnectionAlive(conn *SSEConnection) { |
| ticker := time.NewTicker(10 * time.Second) |
| defer ticker.Stop() |
|
|
| log.Debugf("SSE心跳启动: sessionId=%s, username=%s", conn.SessionID, conn.Username) |
|
|
| for { |
| select { |
| case <-conn.CloseChan: |
| log.Infof("SSE连接已关闭: sessionId=%s", conn.SessionID) |
| log.Infof("SSE连接关闭: sessionId=%s, username=%s", conn.SessionID, conn.Username) |
| return |
| case <-ticker.C: |
| |
| heartbeat := TaskEventMessage{ |
| ID: fmt.Sprintf("heartbeat_%d", time.Now().Unix()), |
| Type: "liveStatus", |
| SessionID: conn.SessionID, |
| Timestamp: time.Now().Unix(), |
| Event: LiveStatusEvent{ |
| ID: fmt.Sprintf("heartbeat_%d", time.Now().Unix()), |
| Type: "liveStatus", |
| Timestamp: time.Now().UnixMilli(), |
| Text: "思考中...", |
| }, |
| } |
|
|
| eventData, err := json.Marshal(heartbeat) |
| if err != nil { |
| log.Errorf("SSE心跳序列化失败: sessionId=%s, error=%v", conn.SessionID, err) |
| continue |
| } |
|
|
| _, err = fmt.Fprintf(conn.Writer, "data: %s\n\n", eventData) |
| if err != nil { |
| log.Errorf("SSE心跳发送失败: sessionId=%s, error=%v", conn.SessionID, err) |
| sm.RemoveConnection(conn.SessionID) |
| return |
| } |
|
|
| conn.Flusher.Flush() |
| conn.LastPing = time.Now() |
| log.Debugf("SSE心跳发送成功: sessionId=%s", conn.SessionID) |
| } |
| } |
| } |
|
|
| |
| func (sm *SSEManager) RemoveConnection(sessionID string) { |
| sm.mutex.Lock() |
| defer sm.mutex.Unlock() |
|
|
| if conn, exists := sm.connections[sessionID]; exists { |
| close(conn.CloseChan) |
| delete(sm.connections, sessionID) |
| log.Infof("SSE连接移除: sessionId=%s, username=%s, remainingConnections=%d", sessionID, conn.Username, len(sm.connections)) |
| } |
| } |
|
|
| |
| func (sm *SSEManager) SendEvent(id string, sessionID string, eventType string, event interface{}) error { |
| sm.mutex.RLock() |
| conn, exists := sm.connections[sessionID] |
| sm.mutex.RUnlock() |
|
|
| if !exists { |
| log.Warnf("SSE连接不存在,跳过事件推送: sessionId=%s, eventType=%s", sessionID, eventType) |
| return fmt.Errorf("连接不存在: sessionId=%s", sessionID) |
| } |
|
|
| log.Debugf("SSE事件推送: sessionId=%s, eventType=%s, eventId=%s", sessionID, eventType, id) |
| return sm.sendEventToConnection(conn, id, eventType, event) |
| } |
|
|
| |
| func (sm *SSEManager) sendEventToConnection(conn *SSEConnection, id string, eventType string, event interface{}) error { |
| |
| eventMessage := TaskEventMessage{ |
| ID: id, |
| Type: eventType, |
| SessionID: conn.SessionID, |
| Timestamp: time.Now().Unix(), |
| Event: event, |
| } |
|
|
| |
| eventData, err := json.Marshal(eventMessage) |
| if err != nil { |
| log.Errorf("SSE事件序列化失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err) |
| return fmt.Errorf("序列化事件失败: %v", err) |
| } |
|
|
| |
| |
| _, err = fmt.Fprintf(conn.Writer, "id: %s\nevent: %s\ndata: %s\n\n", |
| id, eventType, eventData) |
| if err != nil { |
| log.Errorf("SSE事件发送失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err) |
| return fmt.Errorf("发送事件失败: %v", err) |
| } |
|
|
| |
| conn.Flusher.Flush() |
| conn.LastPing = time.Now() |
|
|
| log.Infof("发送事件: sessionId=%s, eventType=%s", conn.SessionID, eventType) |
| log.Debugf("SSE事件发送成功: sessionId=%s, eventType=%s, eventId=%s", conn.SessionID, eventType, id) |
| return nil |
| } |
|
|
| |
| func (sm *SSEManager) GetConnectionCount() int { |
| sm.mutex.RLock() |
| defer sm.mutex.RUnlock() |
| count := len(sm.connections) |
| log.Debugf("SSE连接数统计: count=%d", count) |
| return count |
| } |
|
|
| |
| func (sm *SSEManager) GetConnectionsByUser(username string) []string { |
| sm.mutex.RLock() |
| defer sm.mutex.RUnlock() |
|
|
| var sessionIDs []string |
| for sessionID, conn := range sm.connections { |
| if conn.Username == username { |
| sessionIDs = append(sessionIDs, sessionID) |
| } |
| } |
|
|
| log.Debugf("用户SSE连接查询: username=%s, connectionCount=%d", username, len(sessionIDs)) |
| return sessionIDs |
| } |
|
|
| |
| func (sm *SSEManager) HasConnection(sessionID string) bool { |
| sm.mutex.RLock() |
| defer sm.mutex.RUnlock() |
| _, exists := sm.connections[sessionID] |
| return exists |
| } |
|
|