KarlQuant commited on
Commit
0ec6d5e
Β·
verified Β·
1 Parent(s): ff4ed29

Upload websocket_hub.py

Browse files
Files changed (1) hide show
  1. websocket_hub.py +92 -144
websocket_hub.py CHANGED
@@ -45,7 +45,7 @@ from pathlib import Path
45
  from typing import Dict, List, Optional, Set
46
 
47
  import uvicorn
48
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
49
  from fastapi.middleware.cors import CORSMiddleware
50
  from fastapi.responses import FileResponse, JSONResponse
51
 
@@ -180,14 +180,13 @@ class ConnectionManager:
180
  })
181
 
182
  def __init__(self):
183
- self._publishers: Dict[str, WebSocket] = {}
184
- self._subscribers: Set[WebSocket] = set()
185
- self._top3_subscribers: Set[WebSocket] = set() # PATCH A
186
- self._snapshots: Dict[str, dict] = {}
187
- self._history: Dict[str, deque] = {} # rolling per-space history
188
- self._lock = asyncio.Lock()
189
- self._total_ingested: int = 0
190
- self._msg_counts: Dict[str, Dict[str, int]] = {} # {space: {msg_type: count}}
191
 
192
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
193
  await ws.accept()
@@ -270,7 +269,6 @@ class ConnectionManager:
270
  async with self._lock:
271
  for ws in dead:
272
  self._subscribers.discard(ws)
273
- await self.broadcast_top3() # PATCH C β€” push top-3 on every ingest
274
 
275
  async def send_initial_state(self, ws: WebSocket) -> None:
276
  async with self._lock:
@@ -313,88 +311,6 @@ class ConnectionManager:
313
  "subscriber_count": len(self._subscribers),
314
  }
315
 
316
- # ── PATCH B β€” Top-3 subscriber management ─────────────────────────────────────
317
-
318
- async def register_top3_subscriber(self, ws: WebSocket) -> None:
319
- await ws.accept()
320
- async with self._lock:
321
- self._top3_subscribers.add(ws)
322
- # Push current rankings immediately on connect
323
- await self.broadcast_top3(target=ws)
324
- logger.info(f"πŸ† Top-3 subscriber connected (total={len(self._top3_subscribers)})")
325
-
326
- async def unregister_top3_subscriber(self, ws: WebSocket) -> None:
327
- async with self._lock:
328
- self._top3_subscribers.discard(ws)
329
- logger.info("πŸ† Top-3 subscriber disconnected")
330
-
331
- async def broadcast_top3(self, target: Optional[WebSocket] = None) -> None:
332
- """
333
- Compute top-3 ranked assets and push them to:
334
- β€’ a single WebSocket (target=ws) β€” used on initial connect
335
- β€’ all _top3_subscribers (target=None) β€” used on every ingest
336
- """
337
- # ── Build ranked list (mirrors _compute_rankings() logic) ─────────────────
338
- ranked = []
339
- async with self._lock:
340
- snapshots = dict(self._snapshots)
341
-
342
- for name, snap in snapshots.items():
343
- training = snap.get("training", {})
344
- voting = snap.get("voting", {})
345
- buy = voting.get("buy_count", 0)
346
- sell = voting.get("sell_count", 0)
347
- total = buy + sell
348
- sig_conf = (max(buy, sell) / total) if total > 0 else 0.0
349
- avn_acc = training.get("avn_accuracy", 0.0)
350
- score = round(sig_conf - avn_acc, 6)
351
- ranked.append({
352
- "rank": 0, # filled in below
353
- "space_name": name,
354
- "score": score,
355
- "signal_confidence": round(sig_conf, 6),
356
- "avn_accuracy": round(avn_acc, 6),
357
- "dominant_signal": voting.get("dominant_signal", "NEUTRAL"),
358
- "buy_count": buy,
359
- "sell_count": sell,
360
- "last_price": voting.get("last_price", 0.0),
361
- "training_steps": training.get("training_steps", 0),
362
- "actor_loss": training.get("actor_loss", 0.0),
363
- "critic_loss": training.get("critic_loss", 0.0),
364
- "avn_loss": training.get("avn_loss", 0.0),
365
- "last_updated": snap.get("last_updated", 0.0),
366
- })
367
-
368
- ranked.sort(key=lambda r: r["score"], reverse=True)
369
- for i, r in enumerate(ranked):
370
- r["rank"] = i + 1
371
-
372
- message = json.dumps({
373
- "type": "top3_rankings",
374
- "rankings": ranked[:3],
375
- "total_assets": len(ranked),
376
- "hub_timestamp": time.time(),
377
- })
378
-
379
- # ── Deliver ───────────────────────────────────────────────────────────────
380
- if target is not None:
381
- try:
382
- await target.send_text(message)
383
- except Exception:
384
- pass
385
- return
386
-
387
- dead = []
388
- for ws in list(self._top3_subscribers):
389
- try:
390
- await ws.send_text(message)
391
- except Exception:
392
- dead.append(ws)
393
- if dead:
394
- async with self._lock:
395
- for ws in dead:
396
- self._top3_subscribers.discard(ws)
397
-
398
 
399
  # ══════════════════════════════════════════════════════════════════════════════════════
400
  # SECTION 3 β€” HUB TRADE STORE (in-memory, fed by WebSocket messages)
@@ -539,6 +455,13 @@ _LOG_DIR = os.environ.get("RANKER_LOG_DIR", "/app/ranker_logs")
539
  _hub_trades = HubTradeStore()
540
  logger.info("βœ… HubTradeStore initialised β€” awaiting trade_opened/trade_closed WS messages")
541
 
 
 
 
 
 
 
 
542
 
543
  # ══════════════════════════════════════════════════════════════════════════════════════
544
  # SECTION 4 β€” FASTAPI APPLICATION
@@ -1079,6 +1002,51 @@ _HTML_PATH = Path(os.environ.get(
1079
 
1080
 
1081
  def _compute_rankings() -> List[dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1082
  ranked: List[dict] = []
1083
  for name, snap in manager.get_all_snapshots().items():
1084
  training = snap.get("training", {})
@@ -1093,6 +1061,7 @@ def _compute_rankings() -> List[dict]:
1093
  "rank": 0,
1094
  "space_name": name,
1095
  "score": score,
 
1096
  "signal_confidence": round(sig_conf, 6),
1097
  "avn_accuracy": round(avn_acc, 6),
1098
  "dominant_signal": voting.get("dominant_signal", "NEUTRAL"),
@@ -1103,6 +1072,7 @@ def _compute_rankings() -> List[dict]:
1103
  "critic_loss": training.get("critic_loss", 0.0),
1104
  "avn_loss": training.get("avn_loss", 0.0),
1105
  "last_updated": snap.get("last_updated", 0.0),
 
1106
  })
1107
  ranked.sort(key=lambda r: r["score"], reverse=True)
1108
  for i, r in enumerate(ranked):
@@ -1131,6 +1101,36 @@ async def serve_dashboard():
1131
  )
1132
 
1133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1134
  @app.get("/api/state")
1135
  async def api_state():
1136
  """Full dashboard state polled by hub_dashboard.html every 2 s."""
@@ -1162,58 +1162,6 @@ async def api_state():
1162
  _START_TIME = time.time()
1163
 
1164
 
1165
- # ══════════════════════════════════════════════════════════════════════════════════════
1166
- # SECTION 8b β€” TOP-3 RANKINGS ENDPOINTS (PATCH D)
1167
- # ══════════════════════════════════════════════════════════════════════════════════════
1168
-
1169
- @app.get("/api/rankings/top3")
1170
- async def api_top3_rankings():
1171
- """
1172
- GET /api/rankings/top3
1173
-
1174
- Returns the top 3 assets scored by:
1175
- score = signal_confidence βˆ’ avn_accuracy
1176
- (higher = stronger signal conviction with lower training noise)
1177
-
1178
- Suitable for one-shot polling from your IDE or any HTTP client.
1179
- For real-time streaming connect to ws://<host>/ws/top3 instead.
1180
- """
1181
- rankings = _compute_rankings()[:3]
1182
- return JSONResponse({
1183
- "top3": rankings,
1184
- "total_assets": len(_compute_rankings()),
1185
- "timestamp": datetime.utcnow().isoformat() + "Z",
1186
- })
1187
-
1188
-
1189
- @app.websocket("/ws/top3")
1190
- async def ws_top3_endpoint(websocket: WebSocket):
1191
- """
1192
- WS /ws/top3
1193
-
1194
- Pushes a 'top3_rankings' message to the client every time any asset
1195
- publishes a metrics update. The payload is identical to /api/rankings/top3
1196
- but delivered in real time.
1197
-
1198
- Message shape:
1199
- {
1200
- "type": "top3_rankings",
1201
- "rankings": [ { rank, space_name, score, dominant_signal, … }, … ],
1202
- "total_assets": <int>,
1203
- "hub_timestamp": <unix_float>
1204
- }
1205
- """
1206
- await manager.register_top3_subscriber(websocket)
1207
- try:
1208
- while True:
1209
- # Keep connection alive; we only send, never read from clients here.
1210
- await websocket.receive_text()
1211
- except Exception:
1212
- pass
1213
- finally:
1214
- await manager.unregister_top3_subscriber(websocket)
1215
-
1216
-
1217
  # ══════════════════════════════════════════════════════════════════════════════════════
1218
  # SECTION 9 β€” ENTRY POINT
1219
  # ══════════════════════════════════════════════════════════════════════════════════════
 
45
  from typing import Dict, List, Optional, Set
46
 
47
  import uvicorn
48
+ from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
49
  from fastapi.middleware.cors import CORSMiddleware
50
  from fastapi.responses import FileResponse, JSONResponse
51
 
 
180
  })
181
 
182
  def __init__(self):
183
+ self._publishers: Dict[str, WebSocket] = {}
184
+ self._subscribers: Set[WebSocket] = set()
185
+ self._snapshots: Dict[str, dict] = {}
186
+ self._history: Dict[str, deque] = {} # rolling per-space history
187
+ self._lock = asyncio.Lock()
188
+ self._total_ingested: int = 0
189
+ self._msg_counts: Dict[str, Dict[str, int]] = {} # {space: {msg_type: count}}
 
190
 
191
  async def register_publisher(self, space_name: str, ws: WebSocket) -> None:
192
  await ws.accept()
 
269
  async with self._lock:
270
  for ws in dead:
271
  self._subscribers.discard(ws)
 
272
 
273
  async def send_initial_state(self, ws: WebSocket) -> None:
274
  async with self._lock:
 
311
  "subscriber_count": len(self._subscribers),
312
  }
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  # ══════════════════════════════════════════════════════════════════════════════════════
316
  # SECTION 3 β€” HUB TRADE STORE (in-memory, fed by WebSocket messages)
 
455
  _hub_trades = HubTradeStore()
456
  logger.info("βœ… HubTradeStore initialised β€” awaiting trade_opened/trade_closed WS messages")
457
 
458
+ # ── AXRVI live rankings store ─────────────────────────────────────────────────────────
459
+ # Populated by POST /api/axrvi/rankings from the Executo ranker after every
460
+ # rank_and_gate() cycle (~every 5s). Falls back to hub-snapshot scoring when stale.
461
+ _axrvi_rankings: List[dict] = []
462
+ _axrvi_rankings_ts: float = 0.0
463
+ _AXRVI_RANKINGS_TTL: float = 30.0 # seconds before falling back to snapshot scoring
464
+
465
 
466
  # ══════════════════════════════════════════════════════════════════════════════════════
467
  # SECTION 4 β€” FASTAPI APPLICATION
 
1002
 
1003
 
1004
  def _compute_rankings() -> List[dict]:
1005
+ """
1006
+ Build the rankings list served by /api/state.
1007
+
1008
+ Priority order:
1009
+ 1. Live AXRVI rankings pushed by the Executo ranker via
1010
+ POST /api/axrvi/rankings (within the last 30 s).
1011
+ These contain real softmax-Shreve priorities from AXRVINet.
1012
+ 2. Fallback: hub-snapshot vote-ratio scoring used before the
1013
+ ranker connects or if the push is stale (e.g. ranker restart).
1014
+ """
1015
+ global _axrvi_rankings, _axrvi_rankings_ts
1016
+
1017
+ # ── Path 1: fresh AXRVI rankings ────────────────────────────────────────
1018
+ if _axrvi_rankings and (time.time() - _axrvi_rankings_ts) < _AXRVI_RANKINGS_TTL:
1019
+ snapshots = manager.get_all_snapshots()
1020
+ merged: List[dict] = []
1021
+ for r in _axrvi_rankings:
1022
+ name = r.get("space_name", "")
1023
+ snap = snapshots.get(name, {})
1024
+ training = snap.get("training", {})
1025
+ voting = snap.get("voting", {})
1026
+ buy = voting.get("buy_count", r.get("buy_count", 0))
1027
+ sell = voting.get("sell_count", r.get("sell_count", 0))
1028
+ merged.append({
1029
+ # Core AXRVI fields β€” these are the live ranker values
1030
+ "rank": r.get("rank", 0),
1031
+ "space_name": name,
1032
+ "score": r.get("score", 0.0),
1033
+ "final_priority": r.get("final_priority", r.get("score", 0.0)),
1034
+ "signal_confidence": r.get("signal_confidence",0.0),
1035
+ "dominant_signal": r.get("dominant_signal", "NEUTRAL"),
1036
+ "avn_accuracy": r.get("avn_accuracy", 0.0),
1037
+ "epistemic_std": r.get("epistemic_std", 0.0),
1038
+ "training_steps": r.get("training_steps", training.get("training_steps", 0)),
1039
+ # Hub-snapshot fields merged in (latest available)
1040
+ "actor_loss": training.get("actor_loss", 0.0),
1041
+ "critic_loss": training.get("critic_loss", 0.0),
1042
+ "avn_loss": training.get("avn_loss", 0.0),
1043
+ "buy_count": buy,
1044
+ "sell_count": sell,
1045
+ "last_updated": snap.get("last_updated", _axrvi_rankings_ts),
1046
+ })
1047
+ return merged
1048
+
1049
+ # ── Path 2: fallback hub-snapshot scoring ────────────────────────────────
1050
  ranked: List[dict] = []
1051
  for name, snap in manager.get_all_snapshots().items():
1052
  training = snap.get("training", {})
 
1061
  "rank": 0,
1062
  "space_name": name,
1063
  "score": score,
1064
+ "final_priority": score,
1065
  "signal_confidence": round(sig_conf, 6),
1066
  "avn_accuracy": round(avn_acc, 6),
1067
  "dominant_signal": voting.get("dominant_signal", "NEUTRAL"),
 
1072
  "critic_loss": training.get("critic_loss", 0.0),
1073
  "avn_loss": training.get("avn_loss", 0.0),
1074
  "last_updated": snap.get("last_updated", 0.0),
1075
+ "epistemic_std": 0.0,
1076
  })
1077
  ranked.sort(key=lambda r: r["score"], reverse=True)
1078
  for i, r in enumerate(ranked):
 
1101
  )
1102
 
1103
 
1104
+ @app.post("/api/axrvi/rankings")
1105
+ async def receive_axrvi_rankings(request: Request):
1106
+ """
1107
+ Called by the Executo ranker after every rank_and_gate() cycle (~5 s).
1108
+ Stores the live AXRVI-scored ranking list so _compute_rankings() can serve
1109
+ it from /api/state instead of the stale hub-snapshot vote-ratio fallback.
1110
+
1111
+ Expected body:
1112
+ {"rankings": [{"space_name": "V75", "score": 0.24, "rank": 1, ...}, ...]}
1113
+ """
1114
+ global _axrvi_rankings, _axrvi_rankings_ts
1115
+ try:
1116
+ body = await request.json()
1117
+ except Exception as e:
1118
+ return JSONResponse({"ok": False, "error": f"Bad JSON: {e}"}, status_code=400)
1119
+
1120
+ rankings = body.get("rankings", [])
1121
+ if not isinstance(rankings, list):
1122
+ return JSONResponse({"ok": False, "error": "rankings must be a list"}, status_code=400)
1123
+
1124
+ _axrvi_rankings = rankings
1125
+ _axrvi_rankings_ts = time.time()
1126
+ logger.debug(
1127
+ f"[AXRVI Rankings] Received {len(rankings)} assets | "
1128
+ f"top={rankings[0].get('space_name','?')} score={rankings[0].get('score',0):.4f}"
1129
+ if rankings else "[AXRVI Rankings] Received empty list"
1130
+ )
1131
+ return JSONResponse({"ok": True, "count": len(rankings), "ts": _axrvi_rankings_ts})
1132
+
1133
+
1134
  @app.get("/api/state")
1135
  async def api_state():
1136
  """Full dashboard state polled by hub_dashboard.html every 2 s."""
 
1162
  _START_TIME = time.time()
1163
 
1164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1165
  # ══════════════════════════════════════════════════════════════════════════════════════
1166
  # SECTION 9 β€” ENTRY POINT
1167
  # ══════════════════════════════════════════════════════════════════════════════════════