KarlQuant commited on
Commit
a41e1f5
Β·
verified Β·
1 Parent(s): 012fb5c

Update websocket_hub.py

Browse files
Files changed (1) hide show
  1. websocket_hub.py +122 -9
websocket_hub.py CHANGED
@@ -155,6 +155,24 @@ _METRIC_HISTORY_LEN: int = int(os.environ.get("QUASAR_METRIC_HISTORY", "200"))
155
  # ══════════════════════════════════════════════════════════════════════════════════════
156
 
157
  class ConnectionManager:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def __init__(self):
159
  self._publishers: Dict[str, WebSocket] = {}
160
  self._subscribers: Set[WebSocket] = set()
@@ -162,6 +180,7 @@ class ConnectionManager:
162
  self._history: Dict[str, deque] = {} # rolling per-space history
163
  self._lock = asyncio.Lock()
164
  self._total_ingested: int = 0
 
165
 
166
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
167
  await ws.accept()
@@ -261,6 +280,14 @@ class ConnectionManager:
261
  def get_all_snapshots(self) -> dict:
262
  return dict(self._snapshots)
263
 
 
 
 
 
 
 
 
 
264
  def get_metric_history(self) -> dict:
265
  """Return a plain dict of {space_name: [point, …]} for all spaces with history."""
266
  return {name: list(dq) for name, dq in self._history.items()}
@@ -483,25 +510,71 @@ async def ws_publisher_endpoint(websocket: WebSocket, space_name: str):
483
 
484
  msg_type = data.get("type", "")
485
 
 
 
 
 
486
  if msg_type == "metrics":
 
487
  await manager.ingest(space_name, {
488
  "training": data.get("training", {}),
489
  "voting": data.get("voting", {}),
490
  })
 
491
  elif msg_type == "training":
492
- await manager.ingest(space_name, {
493
- "training": data.get("data", {}),
494
- "voting": {},
495
- })
 
 
 
 
 
 
 
 
 
 
 
 
496
  elif msg_type == "voting":
497
- await manager.ingest(space_name, {
498
- "training": {},
499
- "voting": data.get("data", {}),
500
- })
 
 
501
  elif msg_type in ("heartbeat", "identify", "ping"):
502
  pass
 
503
  else:
504
- logger.debug(f"[{space_name}] Unrecognised type '{msg_type}' β€” dropped")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
  except WebSocketDisconnect:
507
  pass
@@ -555,6 +628,46 @@ async def get_health():
555
  }
556
 
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  # ══════════════════════════════════════════════════════════════════════════════════════
559
  # SECTION 7 β€” TRADE API (native β€” replaces patch_websocket_hub.py)
560
  # ══════════════════════════════════════════════════════════════════════════════════════
 
155
  # ══════════════════════════════════════════════════════════════════════════════════════
156
 
157
  class ConnectionManager:
158
+ # All training field names the hub will accept (including common ranker aliases)
159
+ _TRAINING_KEYS: frozenset = frozenset({
160
+ "actor_loss", "critic_loss", "avn_loss", "avn_accuracy", "training_steps",
161
+ "a_loss", "c_loss", "loss_actor", "loss_critic", "loss_avn",
162
+ "acc", "accuracy", "step", "steps", "n_steps",
163
+ })
164
+ _TRAINING_ALIAS: dict = {
165
+ "a_loss": "actor_loss", "loss_actor": "actor_loss",
166
+ "c_loss": "critic_loss", "loss_critic": "critic_loss",
167
+ "loss_avn": "avn_loss",
168
+ "acc": "avn_accuracy","accuracy": "avn_accuracy",
169
+ "step": "training_steps","steps": "training_steps","n_steps": "training_steps",
170
+ }
171
+ _VOTING_KEYS: frozenset = frozenset({
172
+ "dominant_signal", "buy_count", "sell_count", "last_price", "signal_source",
173
+ "signal", "buy", "sell",
174
+ })
175
+
176
  def __init__(self):
177
  self._publishers: Dict[str, WebSocket] = {}
178
  self._subscribers: Set[WebSocket] = set()
 
180
  self._history: Dict[str, deque] = {} # rolling per-space history
181
  self._lock = asyncio.Lock()
182
  self._total_ingested: int = 0
183
+ self._msg_counts: Dict[str, Dict[str, int]] = {} # {space: {msg_type: count}}
184
 
185
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
186
  await ws.accept()
 
280
  def get_all_snapshots(self) -> dict:
281
  return dict(self._snapshots)
282
 
283
+ def record_msg(self, space_name: str, msg_type: str) -> None:
284
+ """Increment per-space message type counter (non-blocking, called from publisher loop)."""
285
+ counts = self._msg_counts.setdefault(space_name, {})
286
+ counts[msg_type] = counts.get(msg_type, 0) + 1
287
+
288
+ def get_msg_counts(self) -> dict:
289
+ return {s: dict(c) for s, c in self._msg_counts.items()}
290
+
291
  def get_metric_history(self) -> dict:
292
  """Return a plain dict of {space_name: [point, …]} for all spaces with history."""
293
  return {name: list(dq) for name, dq in self._history.items()}
 
510
 
511
  msg_type = data.get("type", "")
512
 
513
+ # ── Track per-space message type counts (for /api/debug/hub) ─────────
514
+ manager.record_msg(space_name, msg_type)
515
+
516
+ # ── Route by type ────────────────────────────────────────────────────
517
  if msg_type == "metrics":
518
+ # Combined payload: top-level "training" and "voting" dicts
519
  await manager.ingest(space_name, {
520
  "training": data.get("training", {}),
521
  "voting": data.get("voting", {}),
522
  })
523
+
524
  elif msg_type == "training":
525
+ # Bug A fix: try "data" wrapper first, then fall back to top-level fields.
526
+ # Some rankers send {"type":"training","data":{...}},
527
+ # others send {"type":"training","actor_loss":..., ...} directly.
528
+ training_raw = data.get("data") or {
529
+ manager._TRAINING_ALIAS.get(k, k): v
530
+ for k, v in data.items()
531
+ if k in manager._TRAINING_KEYS and k != "type"
532
+ }
533
+ if training_raw:
534
+ logger.info(
535
+ f"[{space_name}] βš™ training msg | "
536
+ f"keys={list(training_raw.keys())} | "
537
+ f"actor_loss={training_raw.get('actor_loss', training_raw.get('a_loss', 'β€”'))}"
538
+ )
539
+ await manager.ingest(space_name, {"training": training_raw, "voting": {}})
540
+
541
  elif msg_type == "voting":
542
+ voting_raw = data.get("data") or {
543
+ k: v for k, v in data.items()
544
+ if k in manager._VOTING_KEYS and k != "type"
545
+ }
546
+ await manager.ingest(space_name, {"training": {}, "voting": voting_raw})
547
+
548
  elif msg_type in ("heartbeat", "identify", "ping"):
549
  pass
550
+
551
  else:
552
+ # Bug B fix: don't silently swallow. Try to rescue training/voting
553
+ # fields that live at the top level of an unrecognised message type.
554
+ rescued_training = {
555
+ manager._TRAINING_ALIAS.get(k, k): v
556
+ for k, v in data.items()
557
+ if k in manager._TRAINING_KEYS
558
+ }
559
+ rescued_voting = {
560
+ k: v for k, v in data.items()
561
+ if k in manager._VOTING_KEYS
562
+ }
563
+ if rescued_training or rescued_voting:
564
+ logger.warning(
565
+ f"[{space_name}] ⚠ Unknown type='{msg_type}' β€” "
566
+ f"auto-rescued: training_keys={list(rescued_training.keys())} "
567
+ f"voting_keys={list(rescued_voting.keys())}"
568
+ )
569
+ await manager.ingest(space_name, {
570
+ "training": rescued_training,
571
+ "voting": rescued_voting,
572
+ })
573
+ else:
574
+ logger.warning(
575
+ f"[{space_name}] ⚠ Unknown type='{msg_type}' with no "
576
+ f"extractable fields β€” dropped. Full keys: {list(data.keys())}"
577
+ )
578
 
579
  except WebSocketDisconnect:
580
  pass
 
628
  }
629
 
630
 
631
+ @app.get("/api/debug/hub")
632
+ async def api_debug_hub():
633
+ """
634
+ Diagnostic endpoint β€” exposes exactly what the hub has received and stored.
635
+
636
+ Returns per-space:
637
+ msg_counts β€” how many messages of each type arrived
638
+ snapshot β€” current stored training + voting values
639
+ history_len β€” number of history points recorded
640
+
641
+ Use this to confirm whether training messages are arriving and being stored.
642
+ If msg_counts shows training=0 for a space, the asset space is NOT sending
643
+ training messages. If training > 0 but snapshot.training shows zeros, there
644
+ is a field-name or format mismatch.
645
+ """
646
+ snapshots = manager.get_all_snapshots()
647
+ msg_counts = manager.get_msg_counts()
648
+ history_len = {name: len(dq) for name, dq in manager._history.items()}
649
+
650
+ spaces = {}
651
+ for name in set(list(snapshots.keys()) + list(msg_counts.keys())):
652
+ snap = snapshots.get(name, {})
653
+ spaces[name] = {
654
+ "msg_counts": msg_counts.get(name, {}),
655
+ "history_len": history_len.get(name, 0),
656
+ "training": snap.get("training", {}),
657
+ "voting": snap.get("voting", {}),
658
+ "last_updated": snap.get("last_updated", 0),
659
+ "stale_s": round(time.time() - snap.get("last_updated", time.time()), 1),
660
+ }
661
+
662
+ return JSONResponse({
663
+ "spaces": spaces,
664
+ "total_ingested": manager._total_ingested,
665
+ "publisher_count": len(manager._publishers),
666
+ "subscriber_count": len(manager._subscribers),
667
+ "timestamp": datetime.utcnow().isoformat() + "Z",
668
+ })
669
+
670
+
671
  # ══════════════════════════════════════════════════════════════════════════════════════
672
  # SECTION 7 β€” TRADE API (native β€” replaces patch_websocket_hub.py)
673
  # ══════════════════════════════════════════════════════════════════════════════════════