Spaces:
Sleeping
Sleeping
Commit ·
6cf63a9
1
Parent(s): 5eca717
feat(server): add training broadcast server and Colab training FastAPI app
Browse files
- TrainingBroadcastServer: fire-and-forget WS hub, stores full step history
in episode registry for /episode/replay, fixes publish() to use stored
event loop (asyncio.run_coroutine_threadsafe from training threads)
- server/app.py: new Colab training server with /ws/training, /targets,
/episode/demo, /episode/replay/:ep_id; mounts React build + viewer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- server/app.py +162 -0
- server/training_broadcast.py +20 -11
server/app.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
server/app.py — Training WebSocket server for Colab environment.
|
| 3 |
+
|
| 4 |
+
Provides /ws/training for live streaming of RL training episodes to browsers.
|
| 5 |
+
Mount at a publicly accessible URL in Colab (e.g., via ngrok or Colab's proxy).
|
| 6 |
+
|
| 7 |
+
Usage in training:
|
| 8 |
+
from server.app import broadcast
|
| 9 |
+
broadcast.publish(episode_id, {"type": "episode_update", ...})
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import uvicorn
|
| 16 |
+
from fastapi import FastAPI, HTTPException, WebSocket
|
| 17 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 18 |
+
from fastapi.responses import HTMLResponse
|
| 19 |
+
from fastapi.staticfiles import StaticFiles
|
| 20 |
+
|
| 21 |
+
from server.training_broadcast import TrainingBroadcastServer
|
| 22 |
+
|
| 23 |
+
app = FastAPI(title="Optigami Training Server", version="1.0")

# Allow cross-origin connections (Colab public URL → browser).
# Credentials are deliberately disabled: FastAPI's CORS documentation states
# that allow_origins / allow_methods / allow_headers cannot be ["*"] when
# allow_credentials=True, and a wildcard origin with credentials would let any
# site issue credentialed requests. None of the endpoints defined in this
# module read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global broadcast server — import and use from training code
broadcast = TrainingBroadcastServer()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.on_event("startup")
async def _store_loop() -> None:
    """Hand the server's running event loop to the broadcast hub at startup.

    The hub keeps the reference in ``broadcast._loop`` so that code running
    in training threads can schedule coroutines onto this loop.
    """
    from asyncio import get_running_loop

    broadcast._loop = get_running_loop()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@app.websocket("/ws/training")
async def training_ws(websocket: WebSocket) -> None:
    """Spectator endpoint: viewers open this WebSocket to watch training.

    The connection is handed straight to ``broadcast.connect_spectator``;
    this handler returns when that coroutine returns.
    """
    attach_spectator = broadcast.connect_spectator
    await attach_spectator(websocket)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@app.get("/health")
def health() -> dict:
    """Return a liveness payload with the hub's spectator and episode figures."""
    payload: dict = {"status": "ok"}
    payload["spectators"] = broadcast.spectator_count
    payload["active_episodes"] = broadcast.active_episodes
    return payload
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
|
| 61 |
+
|
| 62 |
+
@app.get("/targets")
def get_targets() -> dict:
    """Describe every available task, keyed by task name.

    Names for which ``get_task_by_name`` returns a falsy value are left out.
    Each entry repeats the task's ``difficulty`` under both ``level`` and
    ``difficulty`` and fills in defaults for the optional fields.
    """
    from server.tasks import available_task_names, get_task_by_name

    catalogue: dict = {}
    for task_name in available_task_names():
        spec = get_task_by_name(task_name)
        if not spec:
            # Unknown or empty task definition: nothing to advertise.
            continue
        catalogue[task_name] = {
            "name": task_name,
            "level": spec["difficulty"],
            "description": spec.get("description", ""),
            "n_creases": spec.get("max_folds", 3),
            "difficulty": spec["difficulty"],
            "material": spec.get("material", "paper"),
        }
    return catalogue
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _crease(kind: str, start: list, end: list) -> dict:
    """Build one 180-degree fold record along the segment ``start`` → ``end``."""
    return {"type": kind, "line": {"start": start, "end": end}, "angle": 180.0}


# Scripted fold sequences, keyed by task name, that /episode/demo replays.
_DEMO_SEQUENCES: dict[str, list[dict]] = {
    "half_fold": [
        _crease("valley", [0.0, 0.5], [1.0, 0.5]),
    ],
    "quarter_fold": [
        _crease("valley", [0.0, 0.5], [1.0, 0.5]),
        _crease("valley", [0.5, 0.0], [0.5, 1.0]),
    ],
    "letter_fold": [
        _crease("valley", [0.0, 0.333], [1.0, 0.333]),
        _crease("mountain", [0.0, 0.667], [1.0, 0.667]),
    ],
    "map_fold": [
        _crease("valley", [0.0, 0.5], [1.0, 0.5]),
        _crease("mountain", [0.5, 0.0], [0.5, 1.0]),
    ],
    "solar_panel": [
        _crease("valley", [0.0, 0.25], [1.0, 0.25]),
        _crease("mountain", [0.0, 0.5], [1.0, 0.5]),
        _crease("valley", [0.0, 0.75], [1.0, 0.75]),
    ],
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@app.get("/episode/demo")
def demo_episode(target: str = "half_fold") -> dict:
    """Play the scripted demo folds for ``target`` and return the step trace.

    A target without an entry in ``_DEMO_SEQUENCES`` reuses the ``half_fold``
    script, while the environment is still reset with ``target`` itself.
    Stepping stops early once an observation reports ``done``.

    Returns a dict with ``task_name``, ``task`` (the task definition or
    ``{}``), ``steps`` and ``final_metrics`` (``{}`` when no step ran).
    """
    from server.origami_environment import OrigamiEnvironment
    from server.models import OrigamiAction as NewAction
    from server.tasks import get_task_by_name

    if target in _DEMO_SEQUENCES:
        script = _DEMO_SEQUENCES[target]
    else:
        script = _DEMO_SEQUENCES["half_fold"]

    environment = OrigamiEnvironment()
    observation = environment.reset(task_name=target)

    trace: list[dict] = []
    for step_no, fold in enumerate(script, start=1):
        move = NewAction(
            fold_type=fold["type"],
            fold_line=fold["line"],
            fold_angle=float(fold.get("angle", 180.0)),
        )
        observation = environment.step(move)
        trace.append(
            {
                "step": step_no,
                "fold": fold,
                "paper_state": observation.paper_state,
                "metrics": observation.metrics,
                "done": observation.done,
            }
        )
        if observation.done:
            break

    return {
        "task_name": target,
        "task": get_task_by_name(target) or {},
        "steps": trace,
        "final_metrics": observation.metrics if trace else {},
    }
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@app.get("/episode/replay/{ep_id}")
def replay_episode(ep_id: str) -> dict:
    """Serve a recorded training episode in the /episode/demo response shape.

    Looks ``ep_id`` up in the broadcast hub's registry and raises a 404
    ``HTTPException`` when there is no (truthy) entry. ``final_metrics``
    falls back to the last recorded step's metrics, or ``{}`` when the
    episode has no steps.
    """
    from server.tasks import get_task_by_name

    episode = broadcast._registry.get(ep_id)
    if not episode:
        raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")

    task_name = episode.task_name
    task = get_task_by_name(episode.task_name) or {}
    recorded = episode.steps

    closing_metrics = episode.final_metrics
    if not closing_metrics:
        closing_metrics = episode.steps[-1]["metrics"] if episode.steps else {}

    return {
        "task_name": task_name,
        "task": task,
        "steps": recorded,
        "final_metrics": closing_metrics,
    }
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ── Static files — viewer first, then React app (LAST, catch-all) ──
|
| 137 |
+
|
| 138 |
+
# Both asset directories sit next to the ``server`` package, two levels up
# from this file.
_VIEWER_DIR = Path(__file__).resolve().parent.parent / "viewer"
_BUILD_DIR = _VIEWER_DIR.parent / "build"

if _VIEWER_DIR.exists():
    # Mounted before "/" so the catch-all React mount cannot shadow it.
    app.mount("/viewer", StaticFiles(directory=str(_VIEWER_DIR), html=True), name="viewer")


if not _BUILD_DIR.exists():
    @app.get("/", include_in_schema=False)
    def _no_build() -> HTMLResponse:
        """Placeholder page served at "/" while the React build is missing."""
        notice = (
            "<p>React build not found. Run <code>npm run build</code> in the frontend directory.</p>"
            "<p>Training viewer: <a href='/viewer/training.html'>/viewer/training.html</a></p>"
        )
        return HTMLResponse(notice)
else:
    app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="react")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run(host: str = "0.0.0.0", port: int = 9001) -> None:
    """Launch the training server under uvicorn.

    Intended to be called from a Colab notebook; ``host`` and ``port`` are
    passed through to ``uvicorn.run`` unchanged.
    """
    bind = {"host": host, "port": port}
    uvicorn.run(app, **bind)


if __name__ == "__main__":
    # Executed as a script: serve with the default host and port.
    run()
|
server/training_broadcast.py
CHANGED
|
@@ -27,6 +27,7 @@ class EpisodeInfo:
|
|
| 27 |
observation: dict = field(default_factory=dict)
|
| 28 |
metrics: dict = field(default_factory=dict)
|
| 29 |
fold_history: list = field(default_factory=list)
|
|
|
|
| 30 |
score: Optional[float] = None
|
| 31 |
final_metrics: Optional[dict] = None
|
| 32 |
|
|
@@ -50,17 +51,13 @@ class TrainingBroadcastServer:
|
|
| 50 |
def publish(self, episode_id: str, data: dict) -> None:
|
| 51 |
"""Fire-and-forget: push an update from the training process.
|
| 52 |
|
| 53 |
-
Safe to call from any thread.
|
|
|
|
| 54 |
"""
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
else:
|
| 60 |
-
loop.run_until_complete(self._async_publish(episode_id, data))
|
| 61 |
-
except RuntimeError:
|
| 62 |
-
# No event loop — training without server
|
| 63 |
-
pass
|
| 64 |
|
| 65 |
async def _async_publish(self, episode_id: str, data: dict) -> None:
|
| 66 |
msg_type = data.get("type", "episode_update")
|
|
@@ -91,12 +88,24 @@ class TrainingBroadcastServer:
|
|
| 91 |
ep.score = data.get("score")
|
| 92 |
ep.final_metrics = data.get("final_metrics")
|
| 93 |
else:
|
| 94 |
-
|
|
|
|
| 95 |
ep.status = "running"
|
| 96 |
obs = data.get("observation", {})
|
| 97 |
ep.observation = obs
|
| 98 |
ep.metrics = obs.get("metrics", {})
|
| 99 |
ep.fold_history = obs.get("fold_history", ep.fold_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
await self._broadcast({"episode_id": episode_id, **data})
|
| 102 |
|
|
|
|
| 27 |
observation: dict = field(default_factory=dict)
|
| 28 |
metrics: dict = field(default_factory=dict)
|
| 29 |
fold_history: list = field(default_factory=list)
|
| 30 |
+
steps: list = field(default_factory=list) # full step history for replay
|
| 31 |
score: Optional[float] = None
|
| 32 |
final_metrics: Optional[dict] = None
|
| 33 |
|
|
|
|
| 51 |
def publish(self, episode_id: str, data: dict) -> None:
    """Fire-and-forget: push an update from the training process.

    Safe to call from any thread. Schedules onto the stored event loop
    (set by the FastAPI startup handler). No-op if no loop is available,
    including when the loop is closed while the update is being scheduled.
    Errors raised by the scheduled coroutine are logged, never raised here.

    Args:
        episode_id: Identifier of the episode the update belongs to.
        data: Update payload forwarded to ``_async_publish``.
    """
    loop = self._loop
    if loop is None or loop.is_closed():
        return

    coro = self._async_publish(episode_id, data)
    try:
        future = asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # The loop was closed between the is_closed() check and scheduling.
        # Close the coroutine so it is not reported as "never awaited", and
        # keep the training thread running.
        coro.close()
        return

    def _report(done) -> None:
        # Runs once the coroutine finishes; surfaces failures that would
        # otherwise disappear with the discarded future.
        if done.cancelled():
            return
        exc = done.exception()
        if exc is not None:
            import logging

            logging.getLogger(__name__).warning(
                "publish for episode %s failed: %r", episode_id, exc
            )

    future.add_done_callback(_report)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
async def _async_publish(self, episode_id: str, data: dict) -> None:
|
| 63 |
msg_type = data.get("type", "episode_update")
|
|
|
|
| 88 |
ep.score = data.get("score")
|
| 89 |
ep.final_metrics = data.get("final_metrics")
|
| 90 |
else:
|
| 91 |
+
step_num = data.get("step", ep.step)
|
| 92 |
+
ep.step = step_num
|
| 93 |
ep.status = "running"
|
| 94 |
obs = data.get("observation", {})
|
| 95 |
ep.observation = obs
|
| 96 |
ep.metrics = obs.get("metrics", {})
|
| 97 |
ep.fold_history = obs.get("fold_history", ep.fold_history)
|
| 98 |
+
# Accumulate full step history for /episode/replay
|
| 99 |
+
if step_num > 0:
|
| 100 |
+
fold_hist = obs.get("fold_history", [])
|
| 101 |
+
latest_fold = fold_hist[-1] if fold_hist else {}
|
| 102 |
+
ep.steps.append({
|
| 103 |
+
"step": step_num,
|
| 104 |
+
"fold": latest_fold,
|
| 105 |
+
"paper_state": obs.get("paper_state", {}),
|
| 106 |
+
"metrics": obs.get("metrics", {}),
|
| 107 |
+
"done": obs.get("done", False),
|
| 108 |
+
})
|
| 109 |
|
| 110 |
await self._broadcast({"episode_id": episode_id, **data})
|
| 111 |
|