saadrizvi09 commited on
Commit
5b93121
·
1 Parent(s): 53dbc28
Files changed (2) hide show
  1. routers/ws.py +6 -6
  2. ws_manager.py +15 -3
routers/ws.py CHANGED
@@ -257,10 +257,10 @@ async def session_ws(websocket: WebSocket, session_id: str):
257
  await cart_incr(session_id, item_id, delta)
258
  await cart_bump_version(session_id)
259
  enriched = await get_enriched_cart(session_id, restaurant_id)
260
- asyncio.create_task(
261
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
262
  )
263
- logger.info(f"[WS] cart_add item={item_id} session={session_id[:8]} → broadcast to {manager.count(session_id)} users")
264
 
265
  elif msg_type == "cart_remove":
266
  if not restaurant_id:
@@ -271,10 +271,10 @@ async def session_ws(websocket: WebSocket, session_id: str):
271
  await cart_remove_item(session_id, item_id)
272
  await cart_bump_version(session_id)
273
  enriched = await get_enriched_cart(session_id, restaurant_id)
274
- asyncio.create_task(
275
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
276
  )
277
- logger.info(f"[WS] cart_remove item={item_id} session={session_id[:8]} → broadcast to {manager.count(session_id)} users")
278
 
279
  elif msg_type == "cart_set_qty":
280
  if not restaurant_id:
@@ -289,10 +289,10 @@ async def session_ws(websocket: WebSocket, session_id: str):
289
  await cart_set_qty(session_id, item_id, qty)
290
  await cart_bump_version(session_id)
291
  enriched = await get_enriched_cart(session_id, restaurant_id)
292
- asyncio.create_task(
293
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
294
  )
295
- logger.info(f"[WS] cart_set_qty item={item_id} qty={qty} session={session_id[:8]} → broadcast to {manager.count(session_id)} users")
296
 
297
  except json.JSONDecodeError:
298
  logger.warning(f"[WS] Bad JSON from {session_id[:8]}")
 
257
  await cart_incr(session_id, item_id, delta)
258
  await cart_bump_version(session_id)
259
  enriched = await get_enriched_cart(session_id, restaurant_id)
260
+ task = asyncio.create_task(
261
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
262
  )
263
+ logger.info(f"[WS] cart_add item={item_id} session={session_id[:8]} → created broadcast task for {manager.count(session_id)} users")
264
 
265
  elif msg_type == "cart_remove":
266
  if not restaurant_id:
 
271
  await cart_remove_item(session_id, item_id)
272
  await cart_bump_version(session_id)
273
  enriched = await get_enriched_cart(session_id, restaurant_id)
274
+ task = asyncio.create_task(
275
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
276
  )
277
+ logger.info(f"[WS] cart_remove item={item_id} session={session_id[:8]} → created broadcast task for {manager.count(session_id)} users")
278
 
279
  elif msg_type == "cart_set_qty":
280
  if not restaurant_id:
 
289
  await cart_set_qty(session_id, item_id, qty)
290
  await cart_bump_version(session_id)
291
  enriched = await get_enriched_cart(session_id, restaurant_id)
292
+ task = asyncio.create_task(
293
  manager.broadcast(session_id, {"type": "cart_update", "cart": enriched})
294
  )
295
+ logger.info(f"[WS] cart_set_qty item={item_id} qty={qty} session={session_id[:8]} → created broadcast task for {manager.count(session_id)} users")
296
 
297
  except json.JSONDecodeError:
298
  logger.warning(f"[WS] Bad JSON from {session_id[:8]}")
ws_manager.py CHANGED
@@ -1,9 +1,11 @@
1
  """WebSocket connection manager — singleton shared across routers."""
2
 
3
  from __future__ import annotations
4
- import json
5
  from fastapi import WebSocket
6
 
 
 
7
 
8
  class ConnectionManager:
9
  """Tracks active WebSocket connections per session_id and broadcasts messages."""
@@ -16,25 +18,35 @@ class ConnectionManager:
16
  if session_id not in self.active_connections:
17
  self.active_connections[session_id] = []
18
  self.active_connections[session_id].append(websocket)
 
19
 
20
  def disconnect(self, session_id: str, websocket: WebSocket):
21
  if session_id in self.active_connections:
22
  try:
23
  self.active_connections[session_id].remove(websocket)
 
24
  except ValueError:
25
  pass
26
  if not self.active_connections[session_id]:
27
  del self.active_connections[session_id]
 
28
 
29
  async def broadcast(self, session_id: str, message: dict):
30
  if session_id not in self.active_connections:
 
31
  return
 
 
32
  dead: list[WebSocket] = []
33
- for conn in self.active_connections[session_id]:
 
34
  try:
35
  await conn.send_json(message)
36
- except Exception:
 
 
37
  dead.append(conn)
 
38
  for d in dead:
39
  self.disconnect(session_id, d)
40
 
 
1
  """WebSocket connection manager — singleton shared across routers."""
2
 
3
  from __future__ import annotations
4
+ import json, logging
5
  from fastapi import WebSocket
6
 
7
+ logger = logging.getLogger(__name__)
8
+
9
 
10
  class ConnectionManager:
11
  """Tracks active WebSocket connections per session_id and broadcasts messages."""
 
18
  if session_id not in self.active_connections:
19
  self.active_connections[session_id] = []
20
  self.active_connections[session_id].append(websocket)
21
+ logger.info(f"[Manager] Connected to {session_id[:8]}, total connections: {len(self.active_connections[session_id])}")
22
 
23
  def disconnect(self, session_id: str, websocket: WebSocket):
24
  if session_id in self.active_connections:
25
  try:
26
  self.active_connections[session_id].remove(websocket)
27
+ logger.info(f"[Manager] Disconnected from {session_id[:8]}, remaining: {len(self.active_connections[session_id])}")
28
  except ValueError:
29
  pass
30
  if not self.active_connections[session_id]:
31
  del self.active_connections[session_id]
32
+ logger.info(f"[Manager] No connections left for {session_id[:8]}, removed channel")
33
 
34
  async def broadcast(self, session_id: str, message: dict):
35
  if session_id not in self.active_connections:
36
+ logger.warning(f"[Manager] ⚠️ Cannot broadcast to {session_id[:8]} - no active connections!")
37
  return
38
+ conns = self.active_connections[session_id]
39
+ logger.info(f"[Manager] 📡 Broadcasting '{message.get('type')}' to {len(conns)} connection(s) in {session_id[:8]}")
40
  dead: list[WebSocket] = []
41
+ success_count = 0
42
+ for conn in conns:
43
  try:
44
  await conn.send_json(message)
45
+ success_count += 1
46
+ except Exception as e:
47
+ logger.error(f"[Manager] Failed to send to connection: {e}")
48
  dead.append(conn)
49
+ logger.info(f"[Manager] ✅ Sent to {success_count}/{len(conns)} connections, {len(dead)} dead")
50
  for d in dead:
51
  self.disconnect(session_id, d)
52