Spaces:
Sleeping
Sleeping
File size: 7,783 Bytes
1e49495 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 | """
TrainingBroadcastServer — fire-and-forget broadcast hub for live training viewer.
The RL training process calls publish() after each env.step().
Spectator browsers connect via /ws/training WebSocket.
Broadcast is async and non-blocking: if no viewers are connected, observations are dropped.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Optional
from fastapi import WebSocket, WebSocketDisconnect
logger = logging.getLogger(__name__)
@dataclass
class EpisodeInfo:
    """Latest known state of one training episode, plus its step history.

    Instances are created and mutated by the broadcast server as episode
    messages arrive; field order and defaults are part of the constructor
    signature and must not be reordered.
    """

    episode_id: str
    task_name: str
    status: str = "running"  # "running" | "done" | "timeout" | "error"
    step: int = 0  # most recent step number reported for this episode
    observation: dict[str, Any] = field(default_factory=dict)  # last observation payload
    metrics: dict[str, Any] = field(default_factory=dict)  # "metrics" entry of the last observation
    fold_history: list[Any] = field(default_factory=list)  # "fold_history" entry of the last observation
    steps: list[dict[str, Any]] = field(default_factory=list)  # full step history for replay
    score: Optional[float] = None  # set when the episode finishes
    final_metrics: Optional[dict[str, Any]] = None  # set when the episode finishes
class TrainingBroadcastServer:
    """Central hub for broadcasting RL training observations to spectator WebSockets.

    publish() may be called from any thread; it only schedules work onto the
    event loop stored in ``self._loop``. All state (spectator list, episode
    registry) is mutated inside that loop while holding ``self._lock``.
    WebSocket handlers run in the asyncio event loop.
    """

    # Same logger the module-level ``logging.getLogger(__name__)`` returns.
    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        self._spectators: list[WebSocket] = []
        self._registry: dict[str, EpisodeInfo] = {}
        self._batch_id: int = 0
        # Not assigned anywhere in this class; must be set externally before
        # publish() has any effect.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._lock = asyncio.Lock()

    # -- Episode publishing (called from training thread / async context) --

    def publish(self, episode_id: str, data: dict) -> None:
        """Fire-and-forget: push an update from the training process.

        Safe to call from any thread. Schedules ``_async_publish`` onto the
        stored event loop. No-op if no loop is available or it is closed.
        Failures inside the scheduled coroutine are logged, never raised here.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        coro = self._async_publish(episode_id, data)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            # Close the coroutine so it does not trigger a "never awaited" warning.
            coro.close()
            return
        future.add_done_callback(self._log_publish_failure)

    @classmethod
    def _log_publish_failure(cls, future: Any) -> None:
        """Log the exception (if any) of a finished publish future."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            cls._log.error("training broadcast publish failed: %r", exc, exc_info=exc)

    async def _async_publish(self, episode_id: str, data: dict) -> None:
        """Apply one message to the registry and broadcast it to all spectators.

        ``data["type"]`` selects the handling; a missing type is treated as
        ``"episode_update"``. Batch-level and ``training_done`` messages are
        forwarded as-is; episode messages update the registry first and are
        forwarded with ``episode_id`` merged in.
        """
        msg_type = data.get("type", "episode_update")
        async with self._lock:
            if msg_type == "batch_start":
                self._batch_id = data.get("batch_id", self._batch_id + 1)
                self._registry.clear()
                await self._broadcast(data)
                return
            if msg_type in ("batch_done", "training_done"):
                await self._broadcast(data)
                return

            # episode_update or episode_done
            ep = self._registry.get(episode_id)
            if ep is None:
                ep = EpisodeInfo(episode_id=episode_id, task_name=data.get("task_name", ""))
                self._registry[episode_id] = ep

            if msg_type == "episode_done":
                ep.status = data.get("status", "done")
                ep.score = data.get("score")
                ep.final_metrics = data.get("final_metrics")
            else:
                self._apply_step_update(ep, data)

            # An "episode_id" key already present in data takes precedence.
            await self._broadcast({"episode_id": episode_id, **data})

    @staticmethod
    def _apply_step_update(ep: Any, data: dict) -> None:
        """Fold one episode_update message into the episode record ``ep``.

        A missing or ``None`` "observation" is treated as an empty dict and a
        missing or ``None`` "step" keeps the episode's current step, so a
        sparse message cannot raise.
        """
        obs = data.get("observation") or {}
        step_num = data.get("step")
        if step_num is None:
            step_num = ep.step

        ep.step = step_num
        ep.status = "running"
        ep.observation = obs
        ep.metrics = obs.get("metrics", {})
        ep.fold_history = obs.get("fold_history", ep.fold_history)

        # Accumulate full step history for /episode/replay
        if step_num > 0:
            fold_hist = obs.get("fold_history", [])
            ep.steps.append({
                "step": step_num,
                "fold": fold_hist[-1] if fold_hist else {},
                "paper_state": obs.get("paper_state", {}),
                "metrics": obs.get("metrics", {}),
                "done": obs.get("done", False),
            })

    # -- Spectator management --

    async def connect_spectator(self, websocket: WebSocket) -> None:
        """Accept a new viewer WebSocket and serve it until disconnect.

        The registry snapshot is sent and the viewer registered under a single
        lock acquisition, so the snapshot is always the first message the
        viewer receives. The viewer is dropped when it disconnects, errors, or
        sends nothing for 30 seconds.
        """
        await websocket.accept()
        async with self._lock:
            await self._send_registry_locked(websocket)
            self._spectators.append(websocket)
        try:
            while True:
                # Viewers are read-only; drain any incoming messages (pings etc)
                await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
        except (WebSocketDisconnect, asyncio.TimeoutError):
            pass  # normal ways for a viewer session to end
        except Exception:
            # Still best-effort, but leave a trace instead of hiding the cause.
            self._log.debug("spectator connection ended with an error", exc_info=True)
        finally:
            await self.disconnect_spectator(websocket)

    async def disconnect_spectator(self, websocket: WebSocket) -> None:
        """Remove ``websocket`` from the spectator list (no-op if absent)."""
        async with self._lock:
            self._spectators = [s for s in self._spectators if s is not websocket]

    # -- Batch control --

    async def start_batch(self, batch_id: int, num_episodes: int, prompt_index: int = 0) -> None:
        """Call before starting a new training batch.

        Broadcasts a ``batch_start`` message; handling it clears the registry.
        """
        data = {
            "type": "batch_start",
            "batch_id": batch_id,
            "num_episodes": num_episodes,
            "prompt_index": prompt_index,
        }
        await self._async_publish("__batch__", data)

    async def finish_batch(
        self,
        batch_id: int,
        scores: list[float],
        best_episode_id: str = "",
    ) -> None:
        """Call after all episodes in a batch complete.

        Broadcasts a ``batch_done`` message; ``avg_score`` is 0.0 when
        ``scores`` is empty.
        """
        data = {
            "type": "batch_done",
            "batch_id": batch_id,
            "scores": scores,
            "best_episode_id": best_episode_id,
            "avg_score": sum(scores) / len(scores) if scores else 0.0,
        }
        await self._async_publish("__batch__", data)

    async def clear_batch(self) -> None:
        """Reset episode registry for next batch."""
        async with self._lock:
            self._registry.clear()

    # -- Internals --

    async def _broadcast(self, message: dict) -> None:
        """Send message to all spectators, removing dead connections.

        Values json cannot encode natively are stringified (``default=str``).
        A spectator whose send raises is dropped from the list.
        """
        if not self._spectators:
            return
        payload = json.dumps(message, default=str)
        dead: list[WebSocket] = []
        # Iterate over a copy: the list may be rebound while a send is awaited.
        for ws in tuple(self._spectators):
            try:
                await ws.send_text(payload)
            except Exception:
                dead.append(ws)
        if dead:
            dead_ids = {id(ws) for ws in dead}
            self._spectators = [s for s in self._spectators if id(s) not in dead_ids]
            self._log.debug("dropped %d dead spectator connection(s)", len(dead))

    def _registry_payload(self) -> dict:
        """Build the ``registry`` snapshot message. Caller must hold the lock."""
        episodes = {
            ep_id: {
                "status": ep.status,
                "task": ep.task_name,
                "step": ep.step,
                "observation": ep.observation,
                "metrics": ep.metrics,
                "score": ep.score,
            }
            for ep_id, ep in self._registry.items()
        }
        return {
            "type": "registry",
            "batch_id": self._batch_id,
            "episodes": episodes,
        }

    async def _send_registry_locked(self, websocket: WebSocket) -> None:
        """Send the snapshot to one viewer, best-effort. Caller must hold the lock."""
        try:
            await websocket.send_text(json.dumps(self._registry_payload(), default=str))
        except Exception:
            # A failed snapshot is not fatal here; the caller's receive loop
            # will notice a broken socket.
            self._log.debug("failed to send registry snapshot", exc_info=True)

    async def _send_registry(self, websocket: WebSocket) -> None:
        """Send the full episode registry to a newly connected viewer."""
        async with self._lock:
            await self._send_registry_locked(websocket)

    @property
    def spectator_count(self) -> int:
        """Number of currently registered spectator sockets."""
        return len(self._spectators)

    @property
    def active_episodes(self) -> int:
        """Number of registry entries whose status is ``"running"``."""
        return sum(1 for ep in self._registry.values() if ep.status == "running")
|