KarlQuant committed on
Commit
b734b4b
·
verified ·
1 Parent(s): 3186156

Update websocket_hub.py

Browse files
Files changed (1) hide show
  1. websocket_hub.py +35 -7
websocket_hub.py CHANGED
@@ -33,6 +33,7 @@ import json
33
  import logging
34
  import os
35
  import time
 
36
  from datetime import datetime
37
  from pathlib import Path
38
  from typing import Dict, List, Optional, Set
@@ -146,17 +147,21 @@ def _validate_and_normalize(space_name: str, raw: dict) -> Optional[dict]:
146
  }
147
 
148
 
 
 
 
149
  # ══════════════════════════════════════════════════════════════════════════════════════
150
  # SECTION 2 — CONNECTION MANAGER
151
  # ══════════════════════════════════════════════════════════════════════════════════════
152
 
153
  class ConnectionManager:
154
  def __init__(self):
155
- self._publishers: Dict[str, WebSocket] = {}
156
- self._subscribers: Set[WebSocket] = set()
157
- self._snapshots: Dict[str, dict] = {}
 
158
  self._lock = asyncio.Lock()
159
- self._total_ingested: int = 0
160
 
161
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
162
  await ws.accept()
@@ -198,6 +203,26 @@ class ConnectionManager:
198
  self._total_ingested += 1
199
  snap_copy = copy.deepcopy(snap)
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  await self._broadcast_update(space_name, snap_copy)
202
 
203
  async def _broadcast_update(self, space_name: str, snapshot: dict) -> None:
@@ -236,6 +261,10 @@ class ConnectionManager:
236
  def get_all_snapshots(self) -> dict:
237
  return dict(self._snapshots)
238
 
 
 
 
 
239
  def get_health(self) -> dict:
240
  now = time.time()
241
  return {
@@ -634,11 +663,10 @@ async def serve_dashboard():
634
  @app.get("/api/state")
635
  async def api_state():
636
  """Full dashboard state polled by hub_dashboard.html every 2 s."""
637
- h = manager.get_health()
638
  rankings = _compute_rankings()
639
  return JSONResponse({
640
  "rankings": rankings,
641
- "metric_history": {},
642
  "health": {
643
  "hub_connected": True,
644
  "spaces_connected": len(manager.get_all_snapshots()),
@@ -670,4 +698,4 @@ _START_TIME = time.time()
670
  if __name__ == "__main__":
671
  port = int(os.environ.get("PORT", 7860))
672
  logger.info(f"πŸš€ QUASAR Hub starting on port {port}")
673
- uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
 
33
  import logging
34
  import os
35
  import time
36
+ from collections import deque
37
  from datetime import datetime
38
  from pathlib import Path
39
  from typing import Dict, List, Optional, Set
 
147
  }
148
 
149
 
150
+ _METRIC_HISTORY_LEN: int = int(os.environ.get("QUASAR_METRIC_HISTORY", "200"))
151
+
152
+
153
  # ══════════════════════════════════════════════════════════════════════════════════════
154
  # SECTION 2 — CONNECTION MANAGER
155
  # ══════════════════════════════════════════════════════════════════════════════════════
156
 
157
  class ConnectionManager:
158
  def __init__(self):
159
+ self._publishers: Dict[str, WebSocket] = {}
160
+ self._subscribers: Set[WebSocket] = set()
161
+ self._snapshots: Dict[str, dict] = {}
162
+ self._history: Dict[str, deque] = {} # rolling per-space history
163
  self._lock = asyncio.Lock()
164
+ self._total_ingested: int = 0
165
 
166
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
167
  await ws.accept()
 
203
  self._total_ingested += 1
204
  snap_copy = copy.deepcopy(snap)
205
 
206
+ # ── Rolling metric history (for sparkline charts in dashboard) ────────
207
+ # Only record a point when training fields arrive AND at least one
208
+ # loss/accuracy field is non-zero (avoids flooding history with empty
209
+ # default-value points before training metrics connect).
210
+ training = snap["training"]
211
+ if normalized["training"] and any(
212
+ training.get(k, 0) != 0
213
+ for k in ("actor_loss", "critic_loss", "avn_loss", "avn_accuracy")
214
+ ):
215
+ if space_name not in self._history:
216
+ self._history[space_name] = deque(maxlen=_METRIC_HISTORY_LEN)
217
+ self._history[space_name].append({
218
+ "ts": snap["last_updated"],
219
+ "actor_loss": training.get("actor_loss", 0.0),
220
+ "critic_loss": training.get("critic_loss", 0.0),
221
+ "avn_loss": training.get("avn_loss", 0.0),
222
+ "avn_accuracy": training.get("avn_accuracy", 0.0),
223
+ "training_steps": training.get("training_steps", 0),
224
+ })
225
+
226
  await self._broadcast_update(space_name, snap_copy)
227
 
228
  async def _broadcast_update(self, space_name: str, snapshot: dict) -> None:
 
261
  def get_all_snapshots(self) -> dict:
262
  return dict(self._snapshots)
263
 
264
+ def get_metric_history(self) -> dict:
265
+ """Return a plain dict of {space_name: [point, …]} for all spaces with history."""
266
+ return {name: list(dq) for name, dq in self._history.items()}
267
+
268
  def get_health(self) -> dict:
269
  now = time.time()
270
  return {
 
663
  @app.get("/api/state")
664
  async def api_state():
665
  """Full dashboard state polled by hub_dashboard.html every 2 s."""
 
666
  rankings = _compute_rankings()
667
  return JSONResponse({
668
  "rankings": rankings,
669
+ "metric_history": manager.get_metric_history(),
670
  "health": {
671
  "hub_connected": True,
672
  "spaces_connected": len(manager.get_all_snapshots()),
 
698
  if __name__ == "__main__":
699
  port = int(os.environ.get("PORT", 7860))
700
  logger.info(f"πŸš€ QUASAR Hub starting on port {port}")
701
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")