ianalin123 committed on
Commit
0bcd0b1
·
1 Parent(s): ca61c8d

fix: rename server.py to server_legacy.py, add server/ package

Browse files

server/ package (new engine-based environment) was shadowing server.py,
causing 'Attribute app not found in module server'. Renamed the old
monolithic server.py to server_legacy.py to resolve the conflict.

For local dev use: uvicorn openenv_server.app:app --reload

server/__init__.py ADDED
File without changes
server/models.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv Pydantic models for the origami RL environment.
3
+
4
+ OrigamiAction — one fold per step
5
+ OrigamiObservation — everything the LLM and Three.js viewer need
6
+ OrigamiState — server-side episode tracking
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Optional
11
+
12
+ from pydantic import Field
13
+
14
+ from openenv.core.env_server.types import Action, Observation, State
15
+
16
+
17
class OrigamiAction(Action):
    """One fold operation sent by the client each step.

    Every field has a default, so an empty payload describes a 180 degree
    valley fold along the horizontal midline of the sheet.
    """

    # Kind of fold to perform; the environment treats "stop" as the request
    # to end the episode rather than as a fold.
    fold_type: str = Field(
        "valley",
        description="'valley' | 'mountain' | 'pleat' | 'crimp' | 'stop'",
    )
    # Crease segment given by its two end points.
    fold_line: dict[str, list[float]] = Field(
        default_factory=lambda: {"start": [0.0, 0.5], "end": [1.0, 0.5]},
        description="{'start': [x, y], 'end': [x, y]} normalized 0-1",
    )
    # How far the paper is folded about the crease.
    fold_angle: float = Field(
        180.0,
        description="Fold angle in degrees, 0-180",
    )
    # Which layers of the (possibly already folded) sheet the fold targets.
    layer_select: str = Field(
        "all",
        description="'all' | 'top' | 'bottom'",
    )
36
+
37
+
38
class OrigamiObservation(Observation):
    """Everything the LLM and Three.js viewer need.

    paper_state contains FOLD-compatible geometry + physics data.
    metrics contains all computed quality metrics.
    No render_urls — the browser renders from paper_state directly.
    """

    # Task definition the episode is being played against.
    task: dict[str, Any] = Field(default_factory=dict)
    # Geometry/physics snapshot of the sheet after the latest step.
    paper_state: dict[str, Any] = Field(default_factory=dict)
    # Quality metrics computed for the current paper state.
    metrics: dict[str, Any] = Field(default_factory=dict)
    # Folds applied so far, in order.
    fold_history: list[dict[str, Any]] = Field(default_factory=list)
    # Message describing why the last step failed, if it did.
    error: Optional[str] = None
51
+
52
+
53
class OrigamiState(State):
    """Server-side bookkeeping for the episode in progress."""

    # Name of the task being attempted ("" when no task is loaded).
    task_name: str = ""
    # Number of folds successfully applied so far.
    num_folds_applied: int = 0
    # Validity flag for the current paper state.
    is_valid: bool = True
    # Reward recorded for the episode so far.
    total_reward: float = 0.0
server/origami_environment.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
3
+
4
+ Implements reset() / step() / state following the OpenEnv interface.
5
+ Engine (physics, fold, validation, metrics) lives in engine/.
6
+ No server-side image rendering — paper_state contains all geometry data.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import uuid
13
+ from typing import Any, Optional
14
+
15
+ from openenv.core.env_server.interfaces import Environment
16
+
17
+ from engine.paper import Paper
18
+ from engine.fold_engine import apply_fold
19
+ from engine.physics import simulate
20
+ from engine.validation import validate_state
21
+ from engine.metrics import compute_all_metrics
22
+ from server.models import OrigamiAction, OrigamiObservation, OrigamiState
23
+ from server.tasks import get_task_by_name, sample_task
24
+
25
+
26
def _get_material(name: str):
    """Look up the engine material called *name*, defaulting to paper.

    Any failure while resolving *name* is answered with the "paper"
    material instead of an error.
    """
    # Kept as a function-local import, exactly like the original.
    from engine.materials import get_material

    try:
        material = get_material(name)
    except Exception:
        # NOTE(review): the exception type raised for an unknown material is
        # not visible from this file, so the broad catch is preserved.
        material = get_material("paper")
    return material
34
+
35
+
36
class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
    """Origami folding RL environment.

    Each episode: the agent receives ``paper_state`` + task, applies folds one
    at a time via :meth:`step`, and is rewarded once, on the terminal step.
    An episode ends on a ``'stop'`` action, when the task's ``max_folds`` is
    reached, when validation reports a self-intersection, or when a step is
    aborted (``step()`` before ``reset()``, or a fold the engine rejects).
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Terminal reward handed out when a step is aborted.
    ABORT_REWARD = -5.0

    def __init__(self, **kwargs):
        """Create an environment with no episode loaded; call reset() first."""
        super().__init__(**kwargs)
        self._paper: Optional[Paper] = None
        self._task: Optional[dict] = None
        self._fold_history: list[dict] = []
        self._metrics: dict = {}
        self._validation: dict = {}
        self._error: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._total_reward: float = 0.0

    # -- reset ---------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode on a flat sheet and return the first observation.

        Args:
            seed: Forwarded to ``sample_task`` when a task has to be sampled.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted.
            **kwargs: ``task_name`` selects a specific task. When it is absent
                or does not match a known task, a task is sampled instead.

        Returns:
            Observation with ``done=False`` and ``reward=None``.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._fold_history = []
        self._error = None
        self._total_reward = 0.0

        # Select task: an explicit name wins, anything else is sampled.
        task_name = kwargs.get("task_name")
        if task_name:
            self._task = get_task_by_name(task_name)
        if not self._task:
            self._task = sample_task(seed=seed)

        # Create flat sheet
        mat = _get_material(self._task["material"])
        self._paper = Paper.create_flat_sheet(
            width=self._task["width"],
            height=self._task["height"],
            material=mat,
        )

        # Initial validation + metrics (no physics needed for flat sheet)
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        return self._make_observation(done=False, reward=None)

    # -- step ----------------------------------------------------------------

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Apply one fold (or stop) and return the resulting observation.

        Args:
            action: Fold to apply; ``fold_type == "stop"`` ends the episode.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            A non-terminal observation (``reward=None``) while the episode
            continues, otherwise a terminal one carrying the final reward or
            ``ABORT_REWARD``.
        """
        if self._paper is None or self._task is None:
            return self._abort_episode(
                "Environment not initialised: call reset() before step()"
            )

        self._step_count += 1
        self._error = None

        # Stop action
        if action.fold_type == "stop":
            return self._finalize_episode()

        # Build fold dict
        # NOTE(review): action.layer_select is not forwarded to the engine;
        # confirm whether apply_fold() is expected to receive it.
        fold_dict = {
            "type": action.fold_type,
            "line": action.fold_line,
            "angle": action.fold_angle,
        }

        # Apply fold
        new_paper, err = apply_fold(self._paper, fold_dict)
        if err:
            return self._abort_episode(err)

        self._paper = new_paper
        self._fold_history.append({**fold_dict, "step": self._step_count})

        # Physics relaxation
        try:
            self._paper = simulate(self._paper, fold_percent=1.0)
        except Exception as exc:
            # Continue: a physics failure is reported but does not abort the
            # episode; validation and metrics run on the un-relaxed paper.
            self._error = f"Physics failed: {exc}"

        # Validate
        self._validation = validate_state(self._paper)

        # Metrics
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        # Check termination
        max_folds = self._task.get("max_folds", 50)
        if self._step_count >= max_folds:
            return self._finalize_episode()

        if self._validation.get("self_intersections", 0) > 0:
            self._error = "Self-intersection detected"
            return self._finalize_episode()

        return self._make_observation(done=False, reward=None)

    # -- state ---------------------------------------------------------------

    @property
    def state(self) -> OrigamiState:
        """Snapshot of the server-side episode bookkeeping."""
        return OrigamiState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task.get("name", "") if self._task else "",
            num_folds_applied=len(self._fold_history),
            is_valid=self._metrics.get("is_valid", True),
            total_reward=self._total_reward,
        )

    # -- internals -----------------------------------------------------------

    def _abort_episode(self, error: str) -> OrigamiObservation:
        """End the episode with ``ABORT_REWARD`` and *error* as the message.

        The penalty is also stored in ``_total_reward`` so that ``state``
        agrees with the terminal observation.
        """
        self._error = error
        self._total_reward = self.ABORT_REWARD
        return self._make_observation(done=True, reward=self.ABORT_REWARD)

    def _finalize_episode(self) -> OrigamiObservation:
        """Score the current metrics and emit the terminal observation."""
        reward = self._compute_reward()
        self._total_reward = reward
        return self._make_observation(done=True, reward=reward)

    def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
        """Bundle the current task, paper, metrics, history and error."""
        return OrigamiObservation(
            done=done,
            reward=reward,
            task=self._task or {},
            paper_state=self._paper.to_observation_dict() if self._paper else {},
            metrics=self._metrics,
            fold_history=self._fold_history,
            error=self._error,
        )

    def _compute_reward(self) -> float:
        """Turn the latest metrics into the scalar terminal reward."""
        m = self._metrics
        reward = 0.0

        # Compactness is the main signal
        reward += m.get("compactness", 0.0) * 20.0

        # Bonus for fitting in target box
        if m.get("fits_target_box", False):
            reward += 10.0

        # Bonus for deployability
        # NOTE(review): granted whenever the metric is true, regardless of the
        # task's "must_deploy" flag; confirm that is intended.
        if m.get("is_deployable", False):
            reward += 5.0

        # Penalties for violations
        reward -= m.get("kawasaki_violations", 0) * 2.0
        reward -= m.get("maekawa_violations", 0) * 2.0
        reward -= m.get("self_intersections", 0) * 5.0

        # Penalty for too many folds (encourage efficiency)
        reward -= m.get("fold_count", 0) * 0.5

        # Penalty for exceeding material strain limit, scaled by the overshoot
        max_strain = m.get("max_strain", 0.0)
        strain_limit = self._paper.material.max_strain if self._paper else 0.05
        if max_strain > strain_limit:
            reward -= 3.0 * (max_strain / strain_limit)

        return float(reward)
server/tasks.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task pool and curriculum for the origami RL environment.
3
+
4
+ 7 tasks across 4 difficulty levels.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ from typing import Optional
10
+
11
+
12
def _task(
    name: str,
    description: str,
    *,
    target_ratio: float,
    max_folds: int,
    target_box: list[float],
    difficulty: int,
    width: float = 1.0,
    height: float = 1.0,
    material: str = "paper",
    must_deploy: bool = False,
) -> dict:
    """Build one task record; the key order is the same for every task."""
    return {
        "name": name,
        "description": description,
        "width": width,
        "height": height,
        "material": material,
        "target_ratio": target_ratio,
        "max_folds": max_folds,
        "target_box": target_box,
        "must_deploy": must_deploy,
        "difficulty": difficulty,
    }


# Task registry keyed by task name, listed from easiest to hardest.
TASKS: dict[str, dict] = {
    task["name"]: task
    for task in (
        _task(
            "half_fold",
            "Fold a 1x1 paper sheet in half along the horizontal midline.",
            target_ratio=0.50,
            max_folds=3,
            target_box=[1.0, 0.5, 0.02],
            difficulty=1,
        ),
        _task(
            "quarter_fold",
            "Fold a 1x1 paper sheet into quarters using two perpendicular folds.",
            target_ratio=0.25,
            max_folds=5,
            target_box=[0.5, 0.5, 0.04],
            difficulty=1,
        ),
        _task(
            "letter_fold",
            "Fold a 1x1 paper into thirds (letter fold) using two parallel folds.",
            target_ratio=0.33,
            max_folds=5,
            target_box=[1.0, 0.34, 0.03],
            difficulty=2,
        ),
        _task(
            "map_fold",
            "Fold a 1x1 paper into eighths using a grid fold pattern. Must be re-deployable.",
            target_ratio=0.125,
            max_folds=8,
            target_box=[0.5, 0.25, 0.08],
            must_deploy=True,
            difficulty=2,
        ),
        _task(
            "solar_panel",
            "Pack a 1x1 Mylar solar panel into a compact configuration using a Miura-ori style fold. Must deploy.",
            material="mylar",
            target_ratio=0.05,
            max_folds=20,
            target_box=[0.25, 0.25, 0.05],
            must_deploy=True,
            difficulty=3,
        ),
        _task(
            "shelter_wall",
            "Fold a 1x1 aluminum sheet into a compact structural panel within strain limits.",
            material="aluminum",
            target_ratio=0.10,
            max_folds=15,
            target_box=[0.5, 0.25, 0.1],
            difficulty=3,
        ),
        _task(
            "stent",
            "Fold a 0.5x1.5 nitinol sheet into a compact tube configuration for a medical stent. Superelastic material.",
            width=0.5,
            height=1.5,
            material="nitinol",
            target_ratio=0.09,
            max_folds=25,
            target_box=[0.1, 0.1, 0.15],
            must_deploy=True,
            difficulty=4,
        ),
    )
}
98
+
99
+
100
def get_task_by_name(name: str) -> Optional[dict]:
    """Return a copy of the task registered under *name*, or None if unknown.

    A copy is returned, as the other accessors in this module do, so that a
    caller adding or replacing keys on the result cannot alter the shared
    ``TASKS`` registry. The copy is shallow: nested values such as
    ``target_box`` are still shared with the registry entry.
    """
    task = TASKS.get(name)
    if task is None:
        return None
    return dict(task)
103
+
104
+
105
def sample_task(seed: Optional[int] = None, difficulty: Optional[int] = None) -> dict:
    """Pick one task at random and return a shallow copy of it.

    Args:
        seed: Seed for the private random generator used for the draw.
        difficulty: When given, restrict the draw to tasks of that level; if
            no task has that level, the draw is made from all tasks instead.
    """
    every_task = list(TASKS.values())
    candidates = every_task
    if difficulty is not None:
        matching = [task for task in every_task if task["difficulty"] == difficulty]
        # An empty selection falls back to the complete pool.
        candidates = matching or every_task
    picker = random.Random(seed)
    return dict(picker.choice(candidates))
114
+
115
+
116
def get_tasks_by_difficulty(level: int) -> list[dict]:
    """Collect shallow copies of every task whose difficulty equals *level*."""
    matching = (task for task in TASKS.values() if task["difficulty"] == level)
    return [dict(task) for task in matching]
119
+
120
+
121
def available_task_names() -> list[str]:
    """List the names of all registered tasks in sorted order."""
    return sorted(TASKS)
server/training_broadcast.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TrainingBroadcastServer — fire-and-forget broadcast hub for live training viewer.
3
+
4
+ The RL training process calls publish() after each env.step().
5
+ Spectator browsers connect via /ws/training WebSocket.
6
+ Broadcast is async and non-blocking: if no viewers are connected, observations are dropped.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Optional
15
+
16
+ from fastapi import WebSocket, WebSocketDisconnect
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@dataclass
class EpisodeInfo:
    """Latest known state of one training episode, as held by the hub."""

    # Identity of the episode and the task it is running.
    episode_id: str
    task_name: str

    # Lifecycle marker: "running" | "done" | "timeout" | "error"
    status: str = "running"
    # Step number reported by the most recent update.
    step: int = 0

    # Payload of the most recent update and the pieces extracted from it.
    observation: dict = field(default_factory=dict)
    metrics: dict = field(default_factory=dict)
    fold_history: list = field(default_factory=list)

    # Filled in once the episode has finished.
    score: Optional[float] = None
    final_metrics: Optional[dict] = None
32
+
33
+
34
+ class TrainingBroadcastServer:
35
+ """Central hub for broadcasting RL training observations to spectator WebSockets.
36
+
37
+ Thread-safe: publish() can be called from training threads (ThreadPoolExecutor).
38
+ WebSocket handlers run in the asyncio event loop.
39
+ """
40
+
41
+ def __init__(self) -> None:
42
+ self._spectators: list[WebSocket] = []
43
+ self._registry: dict[str, EpisodeInfo] = {}
44
+ self._batch_id: int = 0
45
+ self._loop: Optional[asyncio.AbstractEventLoop] = None
46
+ self._lock = asyncio.Lock()
47
+
48
+ # ── Episode publishing (called from training thread / async context) ──
49
+
50
+ def publish(self, episode_id: str, data: dict) -> None:
51
+ """Fire-and-forget: push an update from the training process.
52
+
53
+ Safe to call from any thread. If no event loop is running, logs and returns.
54
+ """
55
+ try:
56
+ loop = asyncio.get_event_loop()
57
+ if loop.is_running():
58
+ asyncio.ensure_future(self._async_publish(episode_id, data), loop=loop)
59
+ else:
60
+ loop.run_until_complete(self._async_publish(episode_id, data))
61
+ except RuntimeError:
62
+ # No event loop β€” training without server
63
+ pass
64
+
65
+ async def _async_publish(self, episode_id: str, data: dict) -> None:
66
+ msg_type = data.get("type", "episode_update")
67
+
68
+ async with self._lock:
69
+ if msg_type == "batch_start":
70
+ self._batch_id = data.get("batch_id", self._batch_id + 1)
71
+ self._registry.clear()
72
+ await self._broadcast(data)
73
+ return
74
+
75
+ if msg_type == "batch_done":
76
+ await self._broadcast(data)
77
+ return
78
+
79
+ if msg_type == "training_done":
80
+ await self._broadcast(data)
81
+ return
82
+
83
+ # episode_update or episode_done
84
+ ep = self._registry.setdefault(
85
+ episode_id,
86
+ EpisodeInfo(episode_id=episode_id, task_name=data.get("task_name", "")),
87
+ )
88
+
89
+ if msg_type == "episode_done":
90
+ ep.status = data.get("status", "done")
91
+ ep.score = data.get("score")
92
+ ep.final_metrics = data.get("final_metrics")
93
+ else:
94
+ ep.step = data.get("step", ep.step)
95
+ ep.status = "running"
96
+ obs = data.get("observation", {})
97
+ ep.observation = obs
98
+ ep.metrics = obs.get("metrics", {})
99
+ ep.fold_history = obs.get("fold_history", ep.fold_history)
100
+
101
+ await self._broadcast({"episode_id": episode_id, **data})
102
+
103
+ # ── Spectator management ──
104
+
105
+ async def connect_spectator(self, websocket: WebSocket) -> None:
106
+ """Accept a new viewer WebSocket and serve it until disconnect."""
107
+ await websocket.accept()
108
+
109
+ async with self._lock:
110
+ self._spectators.append(websocket)
111
+
112
+ # Send current registry snapshot immediately
113
+ await self._send_registry(websocket)
114
+
115
+ try:
116
+ while True:
117
+ # Viewers are read-only; drain any incoming messages (pings etc)
118
+ await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
119
+ except (WebSocketDisconnect, asyncio.TimeoutError, Exception):
120
+ pass
121
+ finally:
122
+ await self.disconnect_spectator(websocket)
123
+
124
+ async def disconnect_spectator(self, websocket: WebSocket) -> None:
125
+ async with self._lock:
126
+ self._spectators = [s for s in self._spectators if s is not websocket]
127
+
128
+ # ── Batch control ──
129
+
130
+ async def start_batch(self, batch_id: int, num_episodes: int, prompt_index: int = 0) -> None:
131
+ """Call before starting a new training batch."""
132
+ data = {
133
+ "type": "batch_start",
134
+ "batch_id": batch_id,
135
+ "num_episodes": num_episodes,
136
+ "prompt_index": prompt_index,
137
+ }
138
+ await self._async_publish("__batch__", data)
139
+
140
+ async def finish_batch(
141
+ self,
142
+ batch_id: int,
143
+ scores: list[float],
144
+ best_episode_id: str = "",
145
+ ) -> None:
146
+ """Call after all episodes in a batch complete."""
147
+ data = {
148
+ "type": "batch_done",
149
+ "batch_id": batch_id,
150
+ "scores": scores,
151
+ "best_episode_id": best_episode_id,
152
+ "avg_score": sum(scores) / len(scores) if scores else 0.0,
153
+ }
154
+ await self._async_publish("__batch__", data)
155
+
156
+ async def clear_batch(self) -> None:
157
+ """Reset episode registry for next batch."""
158
+ async with self._lock:
159
+ self._registry.clear()
160
+
161
+ # ── Internals ──
162
+
163
+ async def _broadcast(self, message: dict) -> None:
164
+ """Send message to all spectators, removing dead connections."""
165
+ if not self._spectators:
166
+ return
167
+ payload = json.dumps(message, default=str)
168
+ dead: list[WebSocket] = []
169
+ for ws in list(self._spectators):
170
+ try:
171
+ await ws.send_text(payload)
172
+ except Exception:
173
+ dead.append(ws)
174
+ for ws in dead:
175
+ self._spectators = [s for s in self._spectators if s is not ws]
176
+
177
+ async def _send_registry(self, websocket: WebSocket) -> None:
178
+ """Send the full episode registry to a newly connected viewer."""
179
+ async with self._lock:
180
+ episodes = {
181
+ ep_id: {
182
+ "status": ep.status,
183
+ "task": ep.task_name,
184
+ "step": ep.step,
185
+ "observation": ep.observation,
186
+ "metrics": ep.metrics,
187
+ "score": ep.score,
188
+ }
189
+ for ep_id, ep in self._registry.items()
190
+ }
191
+ payload = {
192
+ "type": "registry",
193
+ "batch_id": self._batch_id,
194
+ "episodes": episodes,
195
+ }
196
+ try:
197
+ await websocket.send_text(json.dumps(payload, default=str))
198
+ except Exception:
199
+ pass
200
+
201
+ @property
202
+ def spectator_count(self) -> int:
203
+ return len(self._spectators)
204
+
205
+ @property
206
+ def active_episodes(self) -> int:
207
+ return sum(1 for ep in self._registry.values() if ep.status == "running")
server.py β†’ server_legacy.py RENAMED
File without changes