Spaces:
Running
Running
Update websocket_hub.py
Browse files- websocket_hub.py +35 -7
websocket_hub.py
CHANGED
|
@@ -33,6 +33,7 @@ import json
|
|
| 33 |
import logging
|
| 34 |
import os
|
| 35 |
import time
|
|
|
|
| 36 |
from datetime import datetime
|
| 37 |
from pathlib import Path
|
| 38 |
from typing import Dict, List, Optional, Set
|
|
@@ -146,17 +147,21 @@ def _validate_and_normalize(space_name: str, raw: dict) -> Optional[dict]:
|
|
| 146 |
}
|
| 147 |
|
| 148 |
|
|
|
|
|
|
|
|
|
|
| 149 |
# ──────────────────────────────────────────────────────────────────────────────────────
|
| 150 |
# SECTION 2 — CONNECTION MANAGER
|
| 151 |
# ──────────────────────────────────────────────────────────────────────────────────────
|
| 152 |
|
| 153 |
class ConnectionManager:
|
| 154 |
def __init__(self):
|
| 155 |
-
self._publishers: Dict[str, WebSocket]
|
| 156 |
-
self._subscribers: Set[WebSocket]
|
| 157 |
-
self._snapshots: Dict[str, dict]
|
|
|
|
| 158 |
self._lock = asyncio.Lock()
|
| 159 |
-
self._total_ingested: int
|
| 160 |
|
| 161 |
async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
|
| 162 |
await ws.accept()
|
|
@@ -198,6 +203,26 @@ class ConnectionManager:
|
|
| 198 |
self._total_ingested += 1
|
| 199 |
snap_copy = copy.deepcopy(snap)
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
await self._broadcast_update(space_name, snap_copy)
|
| 202 |
|
| 203 |
async def _broadcast_update(self, space_name: str, snapshot: dict) -> None:
|
|
@@ -236,6 +261,10 @@ class ConnectionManager:
|
|
| 236 |
def get_all_snapshots(self) -> dict:
|
| 237 |
return dict(self._snapshots)
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
def get_health(self) -> dict:
|
| 240 |
now = time.time()
|
| 241 |
return {
|
|
@@ -634,11 +663,10 @@ async def serve_dashboard():
|
|
| 634 |
@app.get("/api/state")
|
| 635 |
async def api_state():
|
| 636 |
"""Full dashboard state polled by hub_dashboard.html every 2 s."""
|
| 637 |
-
h = manager.get_health()
|
| 638 |
rankings = _compute_rankings()
|
| 639 |
return JSONResponse({
|
| 640 |
"rankings": rankings,
|
| 641 |
-
"metric_history":
|
| 642 |
"health": {
|
| 643 |
"hub_connected": True,
|
| 644 |
"spaces_connected": len(manager.get_all_snapshots()),
|
|
@@ -670,4 +698,4 @@ _START_TIME = time.time()
|
|
| 670 |
if __name__ == "__main__":
|
| 671 |
port = int(os.environ.get("PORT", 7860))
|
| 672 |
logger.info(f"🚀 QUASAR Hub starting on port {port}")
|
| 673 |
-
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
|
|
|
|
| 33 |
import logging
|
| 34 |
import os
|
| 35 |
import time
|
| 36 |
+
from collections import deque
|
| 37 |
from datetime import datetime
|
| 38 |
from pathlib import Path
|
| 39 |
from typing import Dict, List, Optional, Set
|
|
|
|
| 147 |
}
|
| 148 |
|
| 149 |
|
| 150 |
+
_METRIC_HISTORY_LEN: int = int(os.environ.get("QUASAR_METRIC_HISTORY", "200"))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
# ──────────────────────────────────────────────────────────────────────────────────────
|
| 154 |
# SECTION 2 — CONNECTION MANAGER
|
| 155 |
# ──────────────────────────────────────────────────────────────────────────────────────
|
| 156 |
|
| 157 |
class ConnectionManager:
|
| 158 |
def __init__(self):
|
| 159 |
+
self._publishers: Dict[str, WebSocket] = {}
|
| 160 |
+
self._subscribers: Set[WebSocket] = set()
|
| 161 |
+
self._snapshots: Dict[str, dict] = {}
|
| 162 |
+
self._history: Dict[str, deque] = {} # rolling per-space history
|
| 163 |
self._lock = asyncio.Lock()
|
| 164 |
+
self._total_ingested: int = 0
|
| 165 |
|
| 166 |
async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
|
| 167 |
await ws.accept()
|
|
|
|
| 203 |
self._total_ingested += 1
|
| 204 |
snap_copy = copy.deepcopy(snap)
|
| 205 |
|
| 206 |
+
# ── Rolling metric history (for sparkline charts in dashboard) ────────
|
| 207 |
+
# Only record a point when training fields arrive AND at least one
|
| 208 |
+
# loss/accuracy field is non-zero (avoids flooding history with empty
|
| 209 |
+
# default-value points before training metrics connect).
|
| 210 |
+
training = snap["training"]
|
| 211 |
+
if normalized["training"] and any(
|
| 212 |
+
training.get(k, 0) != 0
|
| 213 |
+
for k in ("actor_loss", "critic_loss", "avn_loss", "avn_accuracy")
|
| 214 |
+
):
|
| 215 |
+
if space_name not in self._history:
|
| 216 |
+
self._history[space_name] = deque(maxlen=_METRIC_HISTORY_LEN)
|
| 217 |
+
self._history[space_name].append({
|
| 218 |
+
"ts": snap["last_updated"],
|
| 219 |
+
"actor_loss": training.get("actor_loss", 0.0),
|
| 220 |
+
"critic_loss": training.get("critic_loss", 0.0),
|
| 221 |
+
"avn_loss": training.get("avn_loss", 0.0),
|
| 222 |
+
"avn_accuracy": training.get("avn_accuracy", 0.0),
|
| 223 |
+
"training_steps": training.get("training_steps", 0),
|
| 224 |
+
})
|
| 225 |
+
|
| 226 |
await self._broadcast_update(space_name, snap_copy)
|
| 227 |
|
| 228 |
async def _broadcast_update(self, space_name: str, snapshot: dict) -> None:
|
|
|
|
| 261 |
def get_all_snapshots(self) -> dict:
|
| 262 |
return dict(self._snapshots)
|
| 263 |
|
| 264 |
+
def get_metric_history(self) -> dict:
|
| 265 |
+
"""Return a plain dict of {space_name: [point, …]} for all spaces with history."""
|
| 266 |
+
return {name: list(dq) for name, dq in self._history.items()}
|
| 267 |
+
|
| 268 |
def get_health(self) -> dict:
|
| 269 |
now = time.time()
|
| 270 |
return {
|
|
|
|
| 663 |
@app.get("/api/state")
|
| 664 |
async def api_state():
|
| 665 |
"""Full dashboard state polled by hub_dashboard.html every 2 s."""
|
|
|
|
| 666 |
rankings = _compute_rankings()
|
| 667 |
return JSONResponse({
|
| 668 |
"rankings": rankings,
|
| 669 |
+
"metric_history": manager.get_metric_history(),
|
| 670 |
"health": {
|
| 671 |
"hub_connected": True,
|
| 672 |
"spaces_connected": len(manager.get_all_snapshots()),
|
|
|
|
| 698 |
if __name__ == "__main__":
|
| 699 |
port = int(os.environ.get("PORT", 7860))
|
| 700 |
logger.info(f"🚀 QUASAR Hub starting on port {port}")
|
| 701 |
+
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
|