PYAE1994 commited on
Commit
4498bea
·
verified ·
1 Parent(s): b84e91f

Fix: update api/websocket_manager.py

Browse files
Files changed (1) hide show
  1. api/websocket_manager.py +111 -53
api/websocket_manager.py CHANGED
@@ -1,12 +1,14 @@
1
  """
2
- WebSocket Manager — God Mode+ real-time streaming
 
3
  """
 
4
  import asyncio
5
  import json
6
  import time
7
  import uuid
 
8
  from typing import Dict, List, Optional, Set
9
- from fastapi import WebSocket
10
  import structlog
11
 
12
  log = structlog.get_logger()
@@ -14,63 +16,119 @@ log = structlog.get_logger()
14
 
15
  class WebSocketManager:
16
  def __init__(self):
17
- self._rooms: Dict[str, List[WebSocket]] = {}
18
- self._chat: Dict[str, List[WebSocket]] = {}
19
-
20
- async def connect(self, ws: WebSocket, room: str = "global"):
21
- await ws.accept()
22
- if room not in self._rooms:
23
- self._rooms[room] = []
24
- if ws not in self._rooms[room]:
25
- self._rooms[room].append(ws)
26
- log.debug("WS connected", room=room)
27
-
28
- def disconnect(self, ws: WebSocket, room: str = "global"):
29
- if room in self._rooms:
30
- self._rooms[room] = [w for w in self._rooms[room] if w != ws]
31
- for sid, conns in self._chat.items():
32
- self._chat[sid] = [w for w in conns if w != ws]
33
-
34
- async def emit(self, task_id: str, event_type: str, data: Dict, session_id: str = ""):
35
- room = f"task:{task_id}"
36
- payload = json.dumps({
37
- "type": event_type,
38
- "task_id": task_id,
39
- "session_id": session_id,
40
- "timestamp": time.time(),
41
- "data": data,
42
- "id": uuid.uuid4().hex[:8],
43
- })
44
- await self._broadcast(room, payload)
45
- await self._broadcast("logs", payload)
46
 
47
- async def emit_chat(self, session_id: str, event_type: str, data: Dict):
48
- room = f"chat:{session_id}"
49
- payload = json.dumps({
50
- "type": event_type,
51
- "session_id": session_id,
52
- "timestamp": time.time(),
53
- "data": data,
54
- "id": uuid.uuid4().hex[:8],
 
 
 
 
 
 
 
 
 
 
 
 
55
  })
56
- await self._broadcast(room, payload)
57
 
58
- async def _broadcast(self, room: str, payload: str):
59
- conns = self._rooms.get(room, [])
60
- dead = []
61
- for ws in conns:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
- await ws.send_text(payload)
64
  except Exception:
65
- dead.append(ws)
 
66
  for ws in dead:
67
- if room in self._rooms:
68
- self._rooms[room] = [w for w in self._rooms[room] if w != ws]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  async def heartbeat_loop(self):
 
71
  while True:
72
- await asyncio.sleep(20)
73
- hb = json.dumps({"type": "heartbeat", "timestamp": time.time()})
74
- all_rooms = list(self._rooms.keys())
75
- for room in all_rooms:
76
- await self._broadcast(room, hb)
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ WebSocket Connection Manager — Production Grade
3
+ Handles rooms, heartbeats, event buffering, reconnect support
4
  """
5
+
6
  import asyncio
7
  import json
8
  import time
9
  import uuid
10
+ from collections import defaultdict
11
  from typing import Dict, List, Optional, Set
 
12
  import structlog
13
 
14
  log = structlog.get_logger()
 
16
 
17
  class WebSocketManager:
18
  def __init__(self):
19
+ # room set of websockets
20
+ self._rooms: Dict[str, Set] = defaultdict(set)
21
+ # ws → list of rooms
22
+ self._ws_rooms: Dict[object, Set[str]] = defaultdict(set)
23
+ # Event buffer per room (for replay on reconnect)
24
+ self._event_buffer: Dict[str, List] = defaultdict(list)
25
+ self._buffer_max = 100
26
+ # Active connection count
27
+ self._connection_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ async def connect(self, websocket, room: str):
30
+ await websocket.accept()
31
+ self._rooms[room].add(websocket)
32
+ self._ws_rooms[websocket].add(room)
33
+ self._connection_count += 1
34
+ log.info("WS connected", room=room, total=self._connection_count)
35
+
36
+ # Replay buffered events for this room
37
+ buffered = self._event_buffer.get(room, [])[-20:]
38
+ for event in buffered:
39
+ try:
40
+ await websocket.send_json(event)
41
+ except Exception:
42
+ pass
43
+
44
+ await websocket.send_json({
45
+ "type": "connected",
46
+ "room": room,
47
+ "timestamp": time.time(),
48
+ "buffered_events": len(buffered),
49
  })
 
50
 
51
+ def disconnect(self, websocket, room: Optional[str] = None):
52
+ if room:
53
+ self._rooms[room].discard(websocket)
54
+ self._ws_rooms[websocket].discard(room)
55
+ else:
56
+ for r in list(self._ws_rooms.get(websocket, [])):
57
+ self._rooms[r].discard(websocket)
58
+ self._ws_rooms.pop(websocket, None)
59
+ self._connection_count = max(0, self._connection_count - 1)
60
+ log.info("WS disconnected", room=room, total=self._connection_count)
61
+
62
+ async def broadcast(self, room: str, event: dict):
63
+ """Broadcast event to all sockets in a room."""
64
+ if "timestamp" not in event:
65
+ event["timestamp"] = time.time()
66
+ if "id" not in event:
67
+ event["id"] = str(uuid.uuid4())[:8]
68
+
69
+ # Buffer event
70
+ self._event_buffer[room].append(event)
71
+ if len(self._event_buffer[room]) > self._buffer_max:
72
+ self._event_buffer[room].pop(0)
73
+
74
+ dead = set()
75
+ for ws in list(self._rooms.get(room, [])):
76
  try:
77
+ await ws.send_json(event)
78
  except Exception:
79
+ dead.add(ws)
80
+
81
  for ws in dead:
82
+ self.disconnect(ws, room)
83
+
84
+ async def broadcast_global(self, event: dict):
85
+ """Broadcast to ALL connected websockets."""
86
+ for room in list(self._rooms.keys()):
87
+ await self.broadcast(room, event)
88
+
89
+ async def emit(self, task_id: str, event_type: str, data: dict, session_id: str = ""):
90
+ """Emit a structured event to a task room + logs room."""
91
+ event = {
92
+ "type": event_type,
93
+ "task_id": task_id,
94
+ "session_id": session_id,
95
+ "timestamp": time.time(),
96
+ "data": data,
97
+ }
98
+ await self.broadcast(f"task:{task_id}", event)
99
+ await self.broadcast("logs", event)
100
+ await self.broadcast("agent_status", {
101
+ "type": "agent_event",
102
+ "task_id": task_id,
103
+ "event_type": event_type,
104
+ "timestamp": time.time(),
105
+ })
106
+
107
+ async def emit_chat(self, session_id: str, event_type: str, data: dict):
108
+ """Emit event to a chat session room."""
109
+ event = {
110
+ "type": event_type,
111
+ "session_id": session_id,
112
+ "timestamp": time.time(),
113
+ "data": data,
114
+ }
115
+ await self.broadcast(f"chat:{session_id}", event)
116
 
117
  async def heartbeat_loop(self):
118
+ """Send heartbeat to all connections every 15s."""
119
  while True:
120
+ await asyncio.sleep(15)
121
+ heartbeat = {
122
+ "type": "heartbeat",
123
+ "timestamp": time.time(),
124
+ "connections": self._connection_count,
125
+ }
126
+ for room in list(self._rooms.keys()):
127
+ await self.broadcast(room, heartbeat)
128
+
129
+ def get_stats(self) -> dict:
130
+ return {
131
+ "total_connections": self._connection_count,
132
+ "rooms": {r: len(ws) for r, ws in self._rooms.items()},
133
+ "buffered_events": {r: len(e) for r, e in self._event_buffer.items()},
134
+ }