ianalin123 Claude Sonnet 4.6 commited on
Commit
6cf63a9
·
1 Parent(s): 5eca717

feat(server): add training broadcast server and Colab training FastAPI app

Browse files

- TrainingBroadcastServer: fire-and-forget WS hub, stores full step history
in episode registry for /episode/replay, fixes publish() to use stored
event loop (asyncio.run_coroutine_threadsafe from training threads)
- server/app.py: new Colab training server with /ws/training, /targets,
/episode/demo, /episode/replay/:ep_id; mounts React build + viewer

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. server/app.py +162 -0
  2. server/training_broadcast.py +20 -11
server/app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server/app.py — Training WebSocket server for Colab environment.
3
+
4
+ Provides /ws/training for live streaming of RL training episodes to browsers.
5
+ Mount at a publicly accessible URL in Colab (e.g., via ngrok or Colab's proxy).
6
+
7
+ Usage in training:
8
+ from server.app import broadcast
9
+ broadcast.publish(episode_id, {"type": "episode_update", ...})
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from pathlib import Path
14
+
15
+ import uvicorn
16
+ from fastapi import FastAPI, HTTPException, WebSocket
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import HTMLResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+
21
+ from server.training_broadcast import TrainingBroadcastServer
22
+
23
+ app = FastAPI(title="Optigami Training Server", version="1.0")
24
+
25
+ # Allow cross-origin connections (Colab public URL → browser)
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_credentials=True,
30
+ allow_methods=["*"],
31
+ allow_headers=["*"],
32
+ )
33
+
34
+ # Global broadcast server — import and use from training code
35
+ broadcast = TrainingBroadcastServer()
36
+
37
+
38
+ @app.on_event("startup")
39
+ async def _store_loop() -> None:
40
+ """Capture the asyncio event loop so training threads can schedule coroutines."""
41
+ import asyncio
42
+ broadcast._loop = asyncio.get_running_loop()
43
+
44
+
45
+ @app.websocket("/ws/training")
46
+ async def training_ws(websocket: WebSocket) -> None:
47
+ """Spectator WebSocket endpoint. Viewers connect here to watch training."""
48
+ await broadcast.connect_spectator(websocket)
49
+
50
+
51
+ @app.get("/health")
52
+ def health() -> dict:
53
+ return {
54
+ "status": "ok",
55
+ "spectators": broadcast.spectator_count,
56
+ "active_episodes": broadcast.active_episodes,
57
+ }
58
+
59
+
60
+ # ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
61
+
62
+ @app.get("/targets")
63
+ def get_targets() -> dict:
64
+ from server.tasks import available_task_names, get_task_by_name
65
+ return {
66
+ name: {
67
+ "name": name,
68
+ "level": t["difficulty"],
69
+ "description": t.get("description", ""),
70
+ "n_creases": t.get("max_folds", 3),
71
+ "difficulty": t["difficulty"],
72
+ "material": t.get("material", "paper"),
73
+ }
74
+ for name in available_task_names()
75
+ if (t := get_task_by_name(name))
76
+ }
77
+
78
+
79
+ _DEMO_SEQUENCES: dict[str, list[dict]] = {
80
+ "half_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}],
81
+ "quarter_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
82
+ {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
83
+ "letter_fold": [{"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
84
+ {"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0}],
85
+ "map_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
86
+ {"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
87
+ "solar_panel": [{"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
88
+ {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
89
+ {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0}],
90
+ }
91
+
92
+
93
+ @app.get("/episode/demo")
94
+ def demo_episode(target: str = "half_fold") -> dict:
95
+ from server.origami_environment import OrigamiEnvironment
96
+ from server.models import OrigamiAction as NewAction
97
+ from server.tasks import get_task_by_name
98
+
99
+ folds = _DEMO_SEQUENCES.get(target, _DEMO_SEQUENCES["half_fold"])
100
+ env = OrigamiEnvironment()
101
+ obs = env.reset(task_name=target)
102
+ steps: list[dict] = []
103
+
104
+ for i, fold_dict in enumerate(folds):
105
+ action = NewAction(
106
+ fold_type=fold_dict["type"],
107
+ fold_line=fold_dict["line"],
108
+ fold_angle=float(fold_dict.get("angle", 180.0)),
109
+ )
110
+ obs = env.step(action)
111
+ steps.append({"step": i + 1, "fold": fold_dict,
112
+ "paper_state": obs.paper_state, "metrics": obs.metrics,
113
+ "done": obs.done})
114
+ if obs.done:
115
+ break
116
+
117
+ return {"task_name": target, "task": get_task_by_name(target) or {},
118
+ "steps": steps, "final_metrics": obs.metrics if steps else {}}
119
+
120
+
121
+ @app.get("/episode/replay/{ep_id}")
122
+ def replay_episode(ep_id: str) -> dict:
123
+ """Return a stored training episode in the same format as /episode/demo."""
124
+ from server.tasks import get_task_by_name
125
+ ep = broadcast._registry.get(ep_id)
126
+ if not ep:
127
+ raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")
128
+ return {
129
+ "task_name": ep.task_name,
130
+ "task": get_task_by_name(ep.task_name) or {},
131
+ "steps": ep.steps,
132
+ "final_metrics": ep.final_metrics or (ep.steps[-1]["metrics"] if ep.steps else {}),
133
+ }
134
+
135
+
136
+ # ── Static files — viewer first, then React app (LAST, catch-all) ──
137
+
138
+ _VIEWER_DIR = Path(__file__).resolve().parent.parent / "viewer"
139
+ _BUILD_DIR = Path(__file__).resolve().parent.parent / "build"
140
+
141
+ if _VIEWER_DIR.exists():
142
+ app.mount("/viewer", StaticFiles(directory=str(_VIEWER_DIR), html=True), name="viewer")
143
+
144
+
145
+ if _BUILD_DIR.exists():
146
+ app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="react")
147
+ else:
148
+ @app.get("/", include_in_schema=False)
149
+ def _no_build() -> HTMLResponse:
150
+ return HTMLResponse(
151
+ "<p>React build not found. Run <code>npm run build</code> in the frontend directory.</p>"
152
+ "<p>Training viewer: <a href='/viewer/training.html'>/viewer/training.html</a></p>"
153
+ )
154
+
155
+
156
+ def run(host: str = "0.0.0.0", port: int = 9001) -> None:
157
+ """Start the training server. Call from Colab notebook."""
158
+ uvicorn.run(app, host=host, port=port)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ run()
server/training_broadcast.py CHANGED
@@ -27,6 +27,7 @@ class EpisodeInfo:
27
  observation: dict = field(default_factory=dict)
28
  metrics: dict = field(default_factory=dict)
29
  fold_history: list = field(default_factory=list)
 
30
  score: Optional[float] = None
31
  final_metrics: Optional[dict] = None
32
 
@@ -50,17 +51,13 @@ class TrainingBroadcastServer:
50
  def publish(self, episode_id: str, data: dict) -> None:
51
  """Fire-and-forget: push an update from the training process.
52
 
53
- Safe to call from any thread. If no event loop is running, logs and returns.
 
54
  """
55
- try:
56
- loop = asyncio.get_event_loop()
57
- if loop.is_running():
58
- asyncio.ensure_future(self._async_publish(episode_id, data), loop=loop)
59
- else:
60
- loop.run_until_complete(self._async_publish(episode_id, data))
61
- except RuntimeError:
62
- # No event loop — training without server
63
- pass
64
 
65
  async def _async_publish(self, episode_id: str, data: dict) -> None:
66
  msg_type = data.get("type", "episode_update")
@@ -91,12 +88,24 @@ class TrainingBroadcastServer:
91
  ep.score = data.get("score")
92
  ep.final_metrics = data.get("final_metrics")
93
  else:
94
- ep.step = data.get("step", ep.step)
 
95
  ep.status = "running"
96
  obs = data.get("observation", {})
97
  ep.observation = obs
98
  ep.metrics = obs.get("metrics", {})
99
  ep.fold_history = obs.get("fold_history", ep.fold_history)
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  await self._broadcast({"episode_id": episode_id, **data})
102
 
 
27
  observation: dict = field(default_factory=dict)
28
  metrics: dict = field(default_factory=dict)
29
  fold_history: list = field(default_factory=list)
30
+ steps: list = field(default_factory=list) # full step history for replay
31
  score: Optional[float] = None
32
  final_metrics: Optional[dict] = None
33
 
 
51
  def publish(self, episode_id: str, data: dict) -> None:
52
  """Fire-and-forget: push an update from the training process.
53
 
54
+ Safe to call from any thread. Schedules onto the stored event loop
55
+ (set by the FastAPI startup handler). No-op if no loop is available.
56
  """
57
+ loop = self._loop
58
+ if loop is None or loop.is_closed():
59
+ return
60
+ asyncio.run_coroutine_threadsafe(self._async_publish(episode_id, data), loop)
 
 
 
 
 
61
 
62
  async def _async_publish(self, episode_id: str, data: dict) -> None:
63
  msg_type = data.get("type", "episode_update")
 
88
  ep.score = data.get("score")
89
  ep.final_metrics = data.get("final_metrics")
90
  else:
91
+ step_num = data.get("step", ep.step)
92
+ ep.step = step_num
93
  ep.status = "running"
94
  obs = data.get("observation", {})
95
  ep.observation = obs
96
  ep.metrics = obs.get("metrics", {})
97
  ep.fold_history = obs.get("fold_history", ep.fold_history)
98
+ # Accumulate full step history for /episode/replay
99
+ if step_num > 0:
100
+ fold_hist = obs.get("fold_history", [])
101
+ latest_fold = fold_hist[-1] if fold_hist else {}
102
+ ep.steps.append({
103
+ "step": step_num,
104
+ "fold": latest_fold,
105
+ "paper_state": obs.get("paper_state", {}),
106
+ "metrics": obs.get("metrics", {}),
107
+ "done": obs.get("done", False),
108
+ })
109
 
110
  await self._broadcast({"episode_id": episode_id, **data})
111