ianalin123 commited on
Commit
4c6824f
·
1 Parent(s): 0153179

Merge origin/main into pr/6 — keep pr/6 refactor (openenv, env, no legacy engine)

Browse files
openenv_server/app.py CHANGED
@@ -160,7 +160,6 @@ def _graph_state_to_fold(paper_dict: dict) -> dict:
160
  edges_assignment.append(asgn)
161
 
162
  faces_vertices = _triangulate_vertices(vertices_coords)
163
-
164
  return {
165
  "vertices_coords": vertices_coords,
166
  "edges_vertices": edges_vertices,
 
160
  edges_assignment.append(asgn)
161
 
162
  faces_vertices = _triangulate_vertices(vertices_coords)
 
163
  return {
164
  "vertices_coords": vertices_coords,
165
  "edges_vertices": edges_vertices,
server/app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server/app.py — Training WebSocket server for Colab environment.
3
+
4
+ Provides /ws/training for live streaming of RL training episodes to browsers.
5
+ Mount at a publicly accessible URL in Colab (e.g., via ngrok or Colab's proxy).
6
+
7
+ Usage in training:
8
+ from server.app import broadcast
9
+ broadcast.publish(episode_id, {"type": "episode_update", ...})
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from pathlib import Path
14
+
15
+ import uvicorn
16
+ from fastapi import FastAPI, HTTPException, WebSocket
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import HTMLResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+
21
+ from server.training_broadcast import TrainingBroadcastServer
22
+
23
app = FastAPI(title="Optigami Training Server", version="1.0")

# Allow cross-origin connections (Colab public URL → browser).
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive — fine for a demo/training server, confirm before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global broadcast server — import and use from training code
broadcast = TrainingBroadcastServer()
36
+
37
+
38
@app.on_event("startup")
async def _store_loop() -> None:
    """Record the running event loop on the broadcaster.

    Training code runs in worker threads; it needs this loop object to
    schedule broadcast coroutines back onto the server's event loop.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    broadcast._loop = loop
43
+
44
+
45
@app.websocket("/ws/training")
async def training_ws(websocket: WebSocket) -> None:
    """Accept a spectator connection and hand it to the broadcast hub.

    Browsers watching training connect here; episode events are pushed to
    them by the TrainingBroadcastServer.
    """
    await broadcast.connect_spectator(websocket)
49
+
50
+
51
@app.get("/health")
def health() -> dict:
    """Liveness probe reporting basic broadcast statistics."""
    payload: dict = {"status": "ok"}
    payload["spectators"] = broadcast.spectator_count
    payload["active_episodes"] = broadcast.active_episodes
    return payload
58
+
59
+
60
+ # ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
61
+
62
@app.get("/targets")
def get_targets() -> dict:
    """Return UI metadata for every known task, keyed by task name.

    Tasks that cannot be resolved by name are silently skipped.
    """
    from server.tasks import available_task_names, get_task_by_name

    catalog: dict = {}
    for name in available_task_names():
        t = get_task_by_name(name)
        if not t:
            continue
        catalog[name] = {
            "name": name,
            "level": t["difficulty"],
            "description": t.get("description", ""),
            "n_creases": t.get("max_folds", 3),
            "difficulty": t["difficulty"],
            "material": t.get("material", "paper"),
        }
    return catalog
77
+
78
+
79
# Hand-authored fold sequences used by /episode/demo. Each entry is a list of
# fold dicts ({'type', 'line', 'angle'}) mirroring OrigamiAction's fields, so
# they can be replayed through the environment one step at a time.
_DEMO_SEQUENCES: dict[str, list[dict]] = {
    "half_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}],
    "quarter_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
                     {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
    "letter_fold": [{"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
                    {"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0}],
    "map_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
                 {"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
    "solar_panel": [{"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
                    {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
                    {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0}],
}
91
+
92
+
93
@app.get("/episode/demo")
def demo_episode(target: str = "half_fold") -> dict:
    """Run a canned fold sequence through the environment and return every step.

    Unknown targets fall back to the 'half_fold' sequence so the endpoint
    always yields a playable episode for the React UI.
    """
    from server.origami_environment import OrigamiEnvironment
    from server.models import OrigamiAction as NewAction
    from server.tasks import get_task_by_name

    sequence = _DEMO_SEQUENCES.get(target, _DEMO_SEQUENCES["half_fold"])
    environment = OrigamiEnvironment()
    obs = environment.reset(task_name=target)

    steps: list[dict] = []
    for step_no, fold in enumerate(sequence, start=1):
        obs = environment.step(NewAction(
            fold_type=fold["type"],
            fold_line=fold["line"],
            fold_angle=float(fold.get("angle", 180.0)),
        ))
        steps.append({
            "step": step_no,
            "fold": fold,
            "paper_state": obs.paper_state,
            "metrics": obs.metrics,
            "done": obs.done,
        })
        if obs.done:
            break

    return {
        "task_name": target,
        "task": get_task_by_name(target) or {},
        "steps": steps,
        "final_metrics": obs.metrics if steps else {},
    }
119
+
120
+
121
@app.get("/episode/replay/{ep_id}")
def replay_episode(ep_id: str) -> dict:
    """Return a stored training episode in the same format as /episode/demo."""
    from server.tasks import get_task_by_name

    ep = broadcast._registry.get(ep_id)
    if not ep:
        raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")

    # Prefer explicitly recorded final metrics; otherwise fall back to the
    # last step's metrics, or an empty dict when there were no steps.
    if ep.final_metrics:
        final_metrics = ep.final_metrics
    elif ep.steps:
        final_metrics = ep.steps[-1]["metrics"]
    else:
        final_metrics = {}

    return {
        "task_name": ep.task_name,
        "task": get_task_by_name(ep.task_name) or {},
        "steps": ep.steps,
        "final_metrics": final_metrics,
    }
134
+
135
+
136
# ── Static files — viewer first, then React app (LAST, catch-all) ──

# Both directories are siblings of this package: <repo>/viewer and <repo>/build.
_VIEWER_DIR = Path(__file__).resolve().parent.parent / "viewer"
_BUILD_DIR = Path(__file__).resolve().parent.parent / "build"

if _VIEWER_DIR.exists():
    app.mount("/viewer", StaticFiles(directory=str(_VIEWER_DIR), html=True), name="viewer")


if _BUILD_DIR.exists():
    # Mounted at "/", so it must be registered after every API route above.
    app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="react")
else:
    @app.get("/", include_in_schema=False)
    def _no_build() -> HTMLResponse:
        # Friendly hint instead of a bare 404 when the frontend isn't built.
        return HTMLResponse(
            "<p>React build not found. Run <code>npm run build</code> in the frontend directory.</p>"
            "<p>Training viewer: <a href='/viewer/training.html'>/viewer/training.html</a></p>"
        )
154
+
155
+
156
def run(host: str = "0.0.0.0", port: int = 9001) -> None:
    """Launch uvicorn serving this app (blocking). Intended for Colab notebooks."""
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run()
server/models.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv Pydantic models for the origami RL environment.
3
+
4
+ OrigamiAction — one fold per step
5
+ OrigamiObservation — everything the LLM and Three.js viewer need
6
+ OrigamiState — server-side episode tracking
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Optional
11
+
12
+ from pydantic import BaseModel, Field
13
+
14
# openenv base classes — use them if available, fall back to plain Pydantic
try:
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    # Stand-ins keep this module importable without openenv installed; they
    # expose only the fields the models below rely on.
    Action = BaseModel
    class State(BaseModel):
        """Minimal stand-in for openenv State base class."""
        episode_id: Optional[str] = None
        step_count: int = 0

    class Observation(BaseModel):
        """Minimal stand-in for openenv Observation base class."""
        done: bool = False
        reward: Optional[float] = None
28
+
29
+
30
class OrigamiAction(Action):
    """One fold operation sent by the client each step."""

    # Fold kind; the environment treats 'stop' as an end-of-episode signal.
    fold_type: str = Field(default="valley", description="'valley' | 'mountain' | 'pleat' | 'crimp' | 'stop'")
    # Crease endpoints in normalized paper coordinates.
    fold_line: dict[str, list[float]] = Field(
        default_factory=lambda: {"start": [0.0, 0.5], "end": [1.0, 0.5]},
        description="{'start': [x, y], 'end': [x, y]} normalized 0-1",
    )
    # Dihedral fold angle.
    fold_angle: float = Field(default=180.0, description="Fold angle in degrees, 0-180")
    # Which layers the fold should affect.
    layer_select: str = Field(default="all", description="'all' | 'top' | 'bottom'")
49
+
50
+
51
class OrigamiObservation(Observation):
    """Everything the LLM and Three.js viewer need.

    paper_state contains FOLD-compatible geometry + physics data.
    metrics contains all computed quality metrics.
    No render_urls — the browser renders from paper_state directly.
    """

    task: dict[str, Any] = Field(default_factory=dict)                # active task spec
    paper_state: dict[str, Any] = Field(default_factory=dict)         # geometry for the viewer
    metrics: dict[str, Any] = Field(default_factory=dict)             # computed quality metrics
    fold_history: list[dict[str, Any]] = Field(default_factory=list)  # folds applied so far
    error: Optional[str] = Field(default=None)                        # last fold/physics error, if any
64
+
65
+
66
class OrigamiState(State):
    """Server-side episode tracking."""

    task_name: str = Field(default="")         # name of the active task
    num_folds_applied: int = Field(default=0)  # folds successfully applied this episode
    is_valid: bool = Field(default=True)       # latest validation verdict
    total_reward: float = Field(default=0.0)   # reward assigned at episode end
server/origami_environment.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
3
+
4
+ Implements reset() / step() / state following the OpenEnv interface.
5
+ Engine (physics, fold, validation, metrics) lives in engine/.
6
+ No server-side image rendering — paper_state contains all geometry data.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import uuid
13
+ from typing import Any, Optional
14
+
15
# openenv base class — fall back to plain object if not installed
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    # Generic stand-in so OrigamiEnvironment's subscript syntax still works.
    from typing import Generic, TypeVar
    A = TypeVar("A")
    O = TypeVar("O")
    S = TypeVar("S")
    class Environment(Generic[A, O, S]):
        """Minimal stand-in for openenv.core.env_server.interfaces.Environment."""
        def __init__(self, **kwargs): pass
26
+
27
+ from engine.paper import Paper
28
+ from engine.fold_engine import apply_fold
29
+ from engine.physics import simulate
30
+ from engine.validation import validate_state
31
+ from engine.metrics import compute_all_metrics
32
+ from server.models import OrigamiAction, OrigamiObservation, OrigamiState
33
+ from server.tasks import get_task_by_name, sample_task
34
+
35
+
36
def _get_material(name: str):
    """Return the engine material for *name*, falling back to 'paper'.

    Fix: the original wrapped the import itself in the try and repeated the
    same import inside the except — an ImportError would simply re-raise from
    the fallback, and the broad except hid which step failed. Import once;
    guard only the lookup, so the fallback covers unknown material names.
    """
    from engine.materials import get_material

    try:
        return get_material(name)
    except Exception:
        # Unknown/unsupported material name — default to plain paper.
        return get_material("paper")
44
+
45
+
46
class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
    """Origami folding RL environment.

    Each episode: agent receives paper_state + task, applies folds one at a
    time via step(), receives metrics + reward, ends with 'stop' action or
    when max_folds is reached.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._paper: Optional[Paper] = None        # current sheet geometry
        self._task: Optional[dict] = None          # active task spec
        self._fold_history: list[dict] = []        # folds applied this episode
        self._metrics: dict = {}                   # latest computed metrics
        self._validation: dict = {}                # latest validation result
        self._error: Optional[str] = None          # last fold/physics error
        self._episode_id: Optional[str] = None
        self._step_count: int = 0
        self._total_reward: float = 0.0            # assigned once, at episode end

    # ── reset ─────────────────────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode; kwargs may carry 'task_name' to pick a task."""
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._fold_history = []
        self._error = None
        self._total_reward = 0.0

        # Select task
        task_name = kwargs.get("task_name")
        if task_name:
            self._task = get_task_by_name(task_name)
        # Unknown/missing name → random task (seeded for reproducibility).
        if not self._task:
            self._task = sample_task(seed=seed)

        # Create flat sheet
        mat = _get_material(self._task["material"])
        self._paper = Paper.create_flat_sheet(
            width=self._task["width"],
            height=self._task["height"],
            material=mat,
        )

        # Initial validation + metrics (no physics needed for flat sheet)
        self._validation = validate_state(self._paper)
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        return self._make_observation(done=False, reward=None)

    # ── step ──────────────────────────────────────────────────────────

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Apply one fold; returns the new observation (done=True on terminal)."""
        # step() before reset() — fail the episode immediately.
        if self._paper is None or self._task is None:
            return self._make_observation(done=True, reward=-5.0)

        self._step_count += 1
        self._error = None

        # ── Stop action ───────────────────────────────────────────────
        if action.fold_type == "stop":
            return self._finalize_episode()

        # ── Build fold dict ───────────────────────────────────────────
        fold_dict = {
            "type": action.fold_type,
            "line": action.fold_line,
            "angle": action.fold_angle,
        }

        # ── Apply fold ────────────────────────────────────────────────
        new_paper, err = apply_fold(self._paper, fold_dict)
        if err:
            # Invalid fold terminates the episode with a fixed penalty.
            self._error = err
            return self._make_observation(done=True, reward=-5.0)

        self._paper = new_paper
        self._fold_history.append({**fold_dict, "step": self._step_count})

        # ── Physics relaxation ────────────────────────────────────────
        try:
            self._paper = simulate(self._paper, fold_percent=1.0)
        except Exception as exc:
            self._error = f"Physics failed: {exc}"
            # Continue — don't abort episode on physics failure

        # ── Validate ──────────────────────────────────────────────────
        self._validation = validate_state(self._paper)

        # ── Metrics ───────────────────────────────────────────────────
        self._metrics = compute_all_metrics(self._paper, self._task, self._validation)

        # ── Check termination ─────────────────────────────────────────
        max_folds = self._task.get("max_folds", 50)
        if self._step_count >= max_folds:
            return self._finalize_episode()

        if self._validation.get("self_intersections", 0) > 0:
            self._error = "Self-intersection detected"
            return self._finalize_episode()

        return self._make_observation(done=False, reward=None)

    # ── state ─────────────────────────────────────────────────────────

    @property
    def state(self) -> OrigamiState:
        """Snapshot of server-side episode bookkeeping."""
        return OrigamiState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task.get("name", "") if self._task else "",
            num_folds_applied=len(self._fold_history),
            is_valid=self._metrics.get("is_valid", True),
            total_reward=self._total_reward,
        )

    # ── internals ─────────────────────────────────────────────────────

    def _finalize_episode(self) -> OrigamiObservation:
        # Compute the terminal reward exactly once and emit a done observation.
        reward = self._compute_reward()
        self._total_reward = reward
        return self._make_observation(done=True, reward=reward)

    def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
        return OrigamiObservation(
            done=done,
            reward=reward,
            task=self._task or {},
            paper_state=self._paper.to_observation_dict() if self._paper else {},
            metrics=self._metrics,
            fold_history=self._fold_history,
            error=self._error,
        )

    def _compute_reward(self) -> float:
        """Scalar terminal reward derived from the latest metrics dict."""
        m = self._metrics
        reward = 0.0

        # Compactness is the main signal
        reward += m.get("compactness", 0.0) * 20.0

        # Bonus for fitting in target box
        if m.get("fits_target_box", False):
            reward += 10.0

        # Bonus for deployability (if task requires it)
        if m.get("is_deployable", False):
            reward += 5.0

        # Penalties for violations
        reward -= m.get("kawasaki_violations", 0) * 2.0
        reward -= m.get("maekawa_violations", 0) * 2.0
        reward -= m.get("self_intersections", 0) * 5.0

        # Penalty for too many folds (encourage efficiency)
        reward -= m.get("fold_count", 0) * 0.5

        # Penalty for exceeding material strain limit
        max_strain = m.get("max_strain", 0.0)
        strain_limit = self._paper.material.max_strain if self._paper else 0.05
        if max_strain > strain_limit:
            reward -= 3.0 * (max_strain / strain_limit)

        return float(reward)
server/tasks.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task pool and curriculum for the origami RL environment.
3
+
4
+ 7 tasks across 4 difficulty levels.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ from typing import Optional
10
+
11
+
12
# Task schema (per entry): width/height — sheet size; material — engine
# material name; target_ratio — presumably the desired packed-size ratio
# (TODO confirm against engine.metrics); max_folds — step budget;
# target_box — [w, h, thickness] the folded sheet should fit inside;
# must_deploy — whether the fold must be reversible; difficulty — 1-4
# curriculum level.
TASKS: dict[str, dict] = {
    "half_fold": {
        "name": "half_fold",
        "description": "Fold a 1x1 paper sheet in half along the horizontal midline.",
        "width": 1.0,
        "height": 1.0,
        "material": "paper",
        "target_ratio": 0.50,
        "max_folds": 3,
        "target_box": [1.0, 0.5, 0.02],
        "must_deploy": False,
        "difficulty": 1,
    },
    "quarter_fold": {
        "name": "quarter_fold",
        "description": "Fold a 1x1 paper sheet into quarters using two perpendicular folds.",
        "width": 1.0,
        "height": 1.0,
        "material": "paper",
        "target_ratio": 0.25,
        "max_folds": 5,
        "target_box": [0.5, 0.5, 0.04],
        "must_deploy": False,
        "difficulty": 1,
    },
    "letter_fold": {
        "name": "letter_fold",
        "description": "Fold a 1x1 paper into thirds (letter fold) using two parallel folds.",
        "width": 1.0,
        "height": 1.0,
        "material": "paper",
        "target_ratio": 0.33,
        "max_folds": 5,
        "target_box": [1.0, 0.34, 0.03],
        "must_deploy": False,
        "difficulty": 2,
    },
    "map_fold": {
        "name": "map_fold",
        "description": "Fold a 1x1 paper into eighths using a grid fold pattern. Must be re-deployable.",
        "width": 1.0,
        "height": 1.0,
        "material": "paper",
        "target_ratio": 0.125,
        "max_folds": 8,
        "target_box": [0.5, 0.25, 0.08],
        "must_deploy": True,
        "difficulty": 2,
    },
    "solar_panel": {
        "name": "solar_panel",
        "description": "Pack a 1x1 Mylar solar panel into a compact configuration using a Miura-ori style fold. Must deploy.",
        "width": 1.0,
        "height": 1.0,
        "material": "mylar",
        "target_ratio": 0.05,
        "max_folds": 20,
        "target_box": [0.25, 0.25, 0.05],
        "must_deploy": True,
        "difficulty": 3,
    },
    "shelter_wall": {
        "name": "shelter_wall",
        "description": "Fold a 1x1 aluminum sheet into a compact structural panel within strain limits.",
        "width": 1.0,
        "height": 1.0,
        "material": "aluminum",
        "target_ratio": 0.10,
        "max_folds": 15,
        "target_box": [0.5, 0.25, 0.1],
        "must_deploy": False,
        "difficulty": 3,
    },
    "stent": {
        "name": "stent",
        "description": "Fold a 0.5x1.5 nitinol sheet into a compact tube configuration for a medical stent. Superelastic material.",
        "width": 0.5,
        "height": 1.5,
        "material": "nitinol",
        "target_ratio": 0.09,
        "max_folds": 25,
        "target_box": [0.1, 0.1, 0.15],
        "must_deploy": True,
        "difficulty": 4,
    },
}
98
+
99
+
100
def get_task_by_name(name: str) -> Optional[dict]:
    """Look up a task definition by name; None for unknown names."""
    try:
        return TASKS[name]
    except KeyError:
        return None
103
+
104
+
105
def sample_task(seed: Optional[int] = None, difficulty: Optional[int] = None) -> dict:
    """Sample a random task, optionally filtered by difficulty level.

    Returns a shallow copy so callers may mutate it freely. If the
    difficulty filter matches nothing, the full pool is used instead.
    """
    rng = random.Random(seed)
    candidates = [t for t in TASKS.values() if difficulty is None or t["difficulty"] == difficulty]
    if not candidates:
        candidates = list(TASKS.values())
    return dict(rng.choice(candidates))
114
+
115
+
116
def get_tasks_by_difficulty(level: int) -> list[dict]:
    """Return copies of all tasks at the given difficulty level."""
    matches: list[dict] = []
    for task in TASKS.values():
        if task["difficulty"] == level:
            matches.append(dict(task))
    return matches
119
+
120
+
121
def available_task_names() -> list[str]:
    """Return every task name in sorted order."""
    return sorted(TASKS)
src/App.js CHANGED
@@ -16,7 +16,7 @@ const REPLAY_EP_ID = _urlParams.get('ep') || null;
16
 
17
  function App() {
18
  const [targets, setTargets] = useState({});
19
- const [selectedTarget, setSelectedTarget] = useState('half_horizontal');
20
  const [episode, setEpisode] = useState(null);
21
  const [currentStep, setCurrentStep] = useState(0);
22
  const [playing, setPlaying] = useState(false);
 
16
 
17
  function App() {
18
  const [targets, setTargets] = useState({});
19
+ const [selectedTarget, setSelectedTarget] = useState('half_fold');
20
  const [episode, setEpisode] = useState(null);
21
  const [currentStep, setCurrentStep] = useState(0);
22
  const [playing, setPlaying] = useState(false);
src/components/Fold3DCanvas.js CHANGED
@@ -7,10 +7,8 @@ const PITCH_MAX = Math.PI / 2 - 0.1;
7
  const ZOOM_MIN = 0.3;
8
  const ZOOM_MAX = 5.0;
9
  const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
10
- const MAX_FOLD_RAD = Math.PI * 0.92;
11
- const SIDE_EPS = 1e-7;
12
- const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.95)';
13
- const VALLEY_COLOR = 'rgba(56, 189, 248, 0.95)';
14
 
15
  function clamp(value, min, max) {
16
  return Math.min(Math.max(value, min), max);
@@ -46,6 +44,9 @@ function shadePaper(intensity) {
46
  return `rgb(${r}, ${g}, ${b})`;
47
  }
48
 
 
 
 
49
  function buildGridMesh(resolution = 18) {
50
  const vertices = [];
51
  for (let y = 0; y <= resolution; y += 1) {
@@ -170,7 +171,7 @@ function applyAllFolds(vertices, foldMasks, progresses) {
170
  function projectVertex(vertex, dim, pitch, yaw, zoom) {
171
  let x = vertex[0] - 0.5;
172
  let y = vertex[1] - 0.5;
173
- let z = vertex[2];
174
 
175
  const cp = Math.cos(pitch);
176
  const sp = Math.sin(pitch);
 
7
  const ZOOM_MIN = 0.3;
8
  const ZOOM_MAX = 5.0;
9
  const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
10
+ const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.9)';
11
+ const VALLEY_COLOR = 'rgba(56, 189, 248, 0.9)';
 
 
12
 
13
  function clamp(value, min, max) {
14
  return Math.min(Math.max(value, min), max);
 
44
  return `rgb(${r}, ${g}, ${b})`;
45
  }
46
 
47
+ const SIDE_EPS = 1e-10;
48
+ const MAX_FOLD_RAD = Math.PI;
49
+
50
  function buildGridMesh(resolution = 18) {
51
  const vertices = [];
52
  for (let y = 0; y <= resolution; y += 1) {
 
171
  function projectVertex(vertex, dim, pitch, yaw, zoom) {
172
  let x = vertex[0] - 0.5;
173
  let y = vertex[1] - 0.5;
174
+ let z = vertex[2] || 0;
175
 
176
  const cp = Math.cos(pitch);
177
  const sp = Math.sin(pitch);
training/__init__.py ADDED
File without changes
training/demo.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ training/demo.py — Run 8 zero-shot rollouts and stream them to the grid viewer.
3
+
4
+ Usage:
5
+ cd /path/to/optigami
6
+ python -m training.demo
7
+
8
+ Then open: http://localhost:9001/viewer/training.html
9
+
10
+ Each of the 8 "strategies" is a heuristic that mimics what a pretrained LLM might
11
+ produce for different tasks — varying from near-optimal to poor. This exercises
12
+ the full broadcast → grid viewer pipeline without requiring an LLM API key.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import time
18
+ import uuid
19
+ from typing import Callable
20
+
21
+ import uvicorn
22
+
23
+ from server.app import app, broadcast
24
+ from training.runner import run_batch
25
+
26
+
27
+ # ── 8 zero-shot heuristic strategies ──────────────────────────────────────────
28
+ # Each is a callable: paper_state (dict) → fold_dict
29
+ # These represent the range of strategies a pretrained LLM might generate.
30
+
31
def strategy_perfect_half(paper_state: dict) -> dict:
    """Valley fold exactly at horizontal midline — optimal for half_fold."""
    midline = {"start": [0.0, 0.5], "end": [1.0, 0.5]}
    return {"type": "valley", "line": midline, "angle": 180.0}
34
+
35
+
36
def strategy_slight_offset(paper_state: dict) -> dict:
    """Valley fold slightly off-center — almost optimal."""
    crease = {"start": [0.0, 0.48], "end": [1.0, 0.48]}
    return {"type": "valley", "line": crease, "angle": 180.0}
39
+
40
+
41
def strategy_thirds(paper_state: dict) -> dict:
    """Letter fold at one-third — wrong for half_fold, generates interesting geometry."""
    positions = [0.333, 0.667]
    n = paper_state.get("fold_count", 0)
    if n >= len(positions):
        # Both creases placed — signal the environment to end the episode.
        return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
    y = positions[n]
    kind = "valley" if n == 0 else "mountain"
    return {"type": kind, "line": {"start": [0.0, y], "end": [1.0, y]}, "angle": 180.0}
52
+
53
+
54
def strategy_vertical(paper_state: dict) -> dict:
    """Vertical fold — gets compactness but in wrong dimension for target_box."""
    crease = {"start": [0.5, 0.0], "end": [0.5, 1.0]}
    return {"type": "valley", "line": crease, "angle": 180.0}
57
+
58
+
59
def strategy_mountain(paper_state: dict) -> dict:
    """Mountain fold at midline — same geometry, different assignment."""
    midline = {"start": [0.0, 0.5], "end": [1.0, 0.5]}
    return {"type": "mountain", "line": midline, "angle": 180.0}
62
+
63
+
64
def strategy_accordion(paper_state: dict) -> dict:
    """Accordion 3-fold — overfolds, achieves high compactness but more folds."""
    plan = [(0.25, "valley"), (0.5, "mountain"), (0.75, "valley")]
    n = paper_state.get("fold_count", 0)
    if n >= len(plan):
        # All three creases placed — stop folding.
        return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
    y, kind = plan[n]
    return {"type": kind, "line": {"start": [0.0, y], "end": [1.0, y]}, "angle": 180.0}
76
+
77
+
78
def strategy_diagonal(paper_state: dict) -> dict:
    """Diagonal fold — achieves compactness but irregular bounding box."""
    diagonal = {"start": [0.0, 0.0], "end": [1.0, 1.0]}
    return {"type": "valley", "line": diagonal, "angle": 180.0}
81
+
82
+
83
def strategy_quarter(paper_state: dict) -> dict:
    """Two perpendicular folds — 4x compactness for quarter_fold task."""
    n = paper_state.get("fold_count", 0)
    if n == 0:
        # First crease: horizontal midline.
        return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
    if n == 1:
        # Second crease: vertical midline.
        return {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}
    return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
91
+
92
+
93
# (display_name, strategy_fn) pairs — one grid cell per entry in the viewer.
STRATEGIES: list[tuple[str, Callable]] = [
    ("perfect_half", strategy_perfect_half),
    ("slight_offset", strategy_slight_offset),
    ("thirds_fold", strategy_thirds),
    ("vertical_fold", strategy_vertical),
    ("mountain_fold", strategy_mountain),
    ("accordion_3", strategy_accordion),
    ("diagonal", strategy_diagonal),
    ("quarter_fold", strategy_quarter),
]
103
+
104
+
105
+ # ── Demo runner ────────────────────────────────────────────────────────────────
106
+
107
async def run_demo(task_name: str = "half_fold", delay_s: float = 0.5) -> None:
    """Wait for server to be ready, then fire 8 episodes.

    Runs every strategy in STRATEGIES concurrently (one worker thread each)
    and reports scores through the broadcast hub when the batch completes.
    """
    # Give uvicorn time to bind and call startup hook (sets broadcast._loop)
    await asyncio.sleep(1.5)

    batch_id = 1
    names, fns = zip(*STRATEGIES)
    ep_ids = [f"ep_{name}" for name in names]

    print(f"\n[demo] Starting batch {batch_id} — task: {task_name}")
    print(f"[demo] Open http://localhost:9001/viewer/training.html\n")

    # Signal grid to clear and show G=8
    await broadcast.start_batch(batch_id, len(fns))

    await asyncio.sleep(delay_s)

    # Run all 8 episodes in the thread pool; broadcast_fn fires into this loop
    results = await asyncio.gather(*[
        asyncio.to_thread(
            _run_one,
            fn,
            task_name,
            ep_id,
            broadcast.publish,
        )
        for fn, ep_id in zip(fns, ep_ids)
    ])

    # Pick the winner by final score; ties resolve to the earliest strategy.
    scores = [r["score"] for r in results]
    best_idx = max(range(len(scores)), key=lambda i: scores[i])

    await broadcast.finish_batch(batch_id, scores, best_episode_id=ep_ids[best_idx])

    print("\n[demo] Results:")
    for name, result in zip(names, results):
        print(f"  {name:20s} score={result['score']:+.2f} status={result['status']}")
    print(f"\n[demo] Best: {names[best_idx]} (score={scores[best_idx]:+.2f})")
    print("\n[demo] Grid viewer running. Press Ctrl+C to stop.\n")
146
+
147
+
148
def _run_one(
    strategy_fn: Callable,
    task_name: str,
    ep_id: str,
    broadcast_fn: Callable,
) -> dict:
    """Thin wrapper: adds a small sleep between steps so the viewer can animate.

    Runs a full episode of *strategy_fn* against a fresh environment,
    publishing an 'episode_update' after reset and after every step, then
    a final 'episode_done'. Returns a result dict with score and status.
    Runs in a worker thread (see run_demo); broadcast_fn must be thread-safe.
    """
    from server.models import OrigamiAction
    from server.origami_environment import OrigamiEnvironment

    env = OrigamiEnvironment()
    obs = env.reset(task_name=task_name)

    # Step 0: initial (flat-sheet) observation so the viewer shows the start.
    broadcast_fn(ep_id, {
        "type": "episode_update",
        "episode_id": ep_id,
        "task_name": task_name,
        "step": 0,
        "observation": _obs_dict(obs),
    })

    max_steps = env._task.get("max_folds", 10) if env._task else 10
    status = "done"

    for step_idx in range(max_steps):
        if obs.done:
            break

        time.sleep(0.3)  # pace so the viewer can animate each step

        fold_dict = strategy_fn(obs.paper_state)

        # Strategy chose to stop — end cleanly without applying a fold.
        if fold_dict.get("type") == "stop":
            break

        action = OrigamiAction(
            fold_type=fold_dict["type"],
            fold_line=fold_dict["line"],
            fold_angle=float(fold_dict.get("angle", 180.0)),
        )
        obs = env.step(action)

        broadcast_fn(ep_id, {
            "type": "episode_update",
            "episode_id": ep_id,
            "task_name": task_name,
            "step": step_idx + 1,
            "observation": _obs_dict(obs),
        })

        if obs.done:
            break
    else:
        # for-else: loop exhausted max_steps without a break → ran out of budget.
        status = "timeout"

    # Terminal reward if the env produced one, else the accumulated total.
    score = obs.reward if obs.reward is not None else env._total_reward or 0.0

    broadcast_fn(ep_id, {
        "type": "episode_done",
        "episode_id": ep_id,
        "status": status,
        "score": float(score),
        "final_metrics": obs.metrics,
    })

    return {
        "episode_id": ep_id,
        "score": float(score),
        "final_metrics": obs.metrics,
        "status": status,
    }
219
+
220
+
221
+ def _obs_dict(obs) -> dict:
222
+ try:
223
+ return obs.model_dump()
224
+ except AttributeError:
225
+ return {
226
+ "paper_state": getattr(obs, "paper_state", {}),
227
+ "metrics": getattr(obs, "metrics", {}),
228
+ "fold_history": getattr(obs, "fold_history", []),
229
+ "done": getattr(obs, "done", False),
230
+ "reward": getattr(obs, "reward", None),
231
+ }
232
+
233
+
234
+ # ── Entry point ────────────────────────────────────────────────────────────────
235
+
236
async def _main() -> None:
    """Serve the FastAPI app and drive the demo batch on the same event loop."""
    config = uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
    server = uvicorn.Server(config)

    # Run demo concurrently with the uvicorn server
    await asyncio.gather(
        server.serve(),
        run_demo(task_name="half_fold"),
    )
245
+
246
+
247
if __name__ == "__main__":
    try:
        asyncio.run(_main())
    except KeyboardInterrupt:
        # Ctrl+C from the console — exit quietly with a short notice.
        print("\n[demo] Stopped.")
training/demo_llm.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ training/demo_llm.py — 8 rollouts using Claude as the zero-shot fold strategist.
3
+
4
+ Usage:
5
+ cd /path/to/optigami
6
+ ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
7
+
8
+ Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
9
+ Claude sees the current paper_state metrics and decides the next fold.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import os
16
+ import re
17
+ import time
18
+ from typing import Any
19
+
20
+ import anthropic
21
+ import uvicorn
22
+
23
+ from server.app import app, broadcast
24
+ from server.models import OrigamiAction
25
+ from server.origami_environment import OrigamiEnvironment
26
+ from server.tasks import get_task_by_name
27
+
28
+
29
# Demo configuration.
TASK_NAME = "half_fold"  # task key resolved via get_task_by_name
NUM_EPISODES = 8  # parallel rollouts per batch
MODEL = "claude-haiku-4-5-20251001"  # Claude model queried once per fold step
32
+
33
+
34
+ # ── LLM strategy factory ───────────────────────────────────────────────────────
35
+
36
def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
    """Return a strategy_fn for one episode. Each episode gets its own call history.

    The returned callable receives the flattened paper_state dict and returns
    a fold dict: {"type": ..., "line": {...}, "angle": ...}. If the model's
    reply cannot be parsed as JSON, a safe "stop" fold is returned so the
    episode ends cleanly instead of crashing.
    """
    history: list[dict[str, Any]] = []

    def _stop_fold() -> dict:
        # Fallback fold used whenever the model reply is unparseable.
        return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}

    def strategy(paper_state: dict) -> dict:
        fold_count = paper_state.get("fold_count", 0)
        compactness = paper_state.get("compactness", 0)
        bb = paper_state.get("bounding_box", [1, 1, 0])
        target_box = task.get("target_box", [1, 0.5, 0.02])
        max_folds = task.get("max_folds", 3)

        user_msg = f"""You are folding a {task['width']}x{task['height']} sheet of {task['material']}.
Task: {task['description']}
Target box to fit inside: {target_box}
Max folds allowed: {max_folds}

Current state (fold {fold_count}/{max_folds}):
  compactness: {compactness:.3f} (1.0 = fully packed, 0.0 = flat)
  bounding_box: [{bb[0]:.3f}, {bb[1]:.3f}, {bb[2]:.4f}]
  fits_target_box: {paper_state.get('fits_target_box', False)}

Choose the next fold. Respond with ONLY valid JSON, no other text:
{{
  "type": "valley" or "mountain" or "stop",
  "line": {{"start": [x, y], "end": [x, y]}},
  "angle": 180
}}

Coordinates are normalized 0-1. Use "stop" if done."""

        history.append({"role": "user", "content": user_msg})

        response = client.messages.create(
            model=MODEL,
            max_tokens=120,
            messages=history,
        )
        reply = response.content[0].text.strip()
        history.append({"role": "assistant", "content": reply})

        # BUG FIX: the previous extraction used re.search(r'\{[^{}]+\}', ...),
        # which cannot match nested braces. The prompt above *requests* a nested
        # "line" object, so the regex matched only {"start": ..., "end": ...}
        # instead of the full fold dict. Take the outermost {...} span instead
        # (this also strips markdown code fences around the JSON).
        open_idx = reply.find("{")
        close_idx = reply.rfind("}")
        if open_idx == -1 or close_idx <= open_idx:
            return _stop_fold()

        try:
            fold_dict = json.loads(reply[open_idx:close_idx + 1])
        except json.JSONDecodeError:
            # Malformed JSON from the model — end the episode rather than crash.
            return _stop_fold()
        if not isinstance(fold_dict, dict):
            return _stop_fold()

        # Normalize: ensure required keys
        fold_dict.setdefault("type", "valley")
        fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
        fold_dict.setdefault("angle", 180.0)
        return fold_dict

    return strategy
89
+
90
+
91
+ # ── Episode runner ─────────────────────────────────────────────────────────────
92
+
93
def run_episode_llm(
    strategy_fn,
    task_name: str,
    ep_id: str,
    broadcast_fn,
) -> dict:
    """Run one LLM-driven fold episode, streaming progress via broadcast_fn.

    Args:
        strategy_fn: Callable(paper_state: dict) -> fold dict. Any exception it
            raises ends the episode with status "error".
        task_name: Task key passed to env.reset().
        ep_id: Stable episode identifier echoed in every broadcast payload.
        broadcast_fn: Callable(ep_id, payload) for live streaming to viewers.

    Returns:
        dict with keys: episode_id, score, status.
    """
    env = OrigamiEnvironment()
    obs = env.reset(task_name=task_name)
    # NOTE(review): reads the env's private _task for fold limits — consider a
    # public accessor on OrigamiEnvironment.
    task = env._task or {}

    # Step 0 broadcast: initial (unfolded) observation.
    broadcast_fn(ep_id, {
        "type": "episode_update",
        "episode_id": ep_id,
        "task_name": task_name,
        "step": 0,
        "observation": _obs_dict(obs),
    })

    max_steps = task.get("max_folds", 5)
    status = "done"

    for step_idx in range(max_steps):
        if obs.done:
            break

        # Build a flat paper_state dict for the LLM (add metrics inline)
        ps = dict(obs.paper_state)
        ps.update(obs.metrics)  # compactness, fits_target_box, etc.
        ps["fold_count"] = step_idx

        try:
            fold_dict = strategy_fn(ps)
        except Exception as exc:
            # Strategy (LLM call / parsing) failed: report the error and bail
            # out early with a zero score.
            broadcast_fn(ep_id, {
                "type": "episode_done", "episode_id": ep_id,
                "status": "error", "score": 0.0,
                "final_metrics": obs.metrics, "error": str(exc),
            })
            return {"episode_id": ep_id, "score": 0.0, "status": "error"}

        # The strategy may voluntarily end the episode.
        if fold_dict.get("type") == "stop":
            break

        time.sleep(0.4)  # pace for viewer animation

        action = OrigamiAction(
            fold_type=fold_dict["type"],
            fold_line=fold_dict["line"],
            fold_angle=float(fold_dict.get("angle", 180.0)),
        )
        obs = env.step(action)

        broadcast_fn(ep_id, {
            "type": "episode_update",
            "episode_id": ep_id,
            "task_name": task_name,
            "step": step_idx + 1,
            "observation": _obs_dict(obs),
        })

        if obs.done:
            break
    else:
        # Loop exhausted max_steps without done/stop/break → timeout.
        status = "timeout"

    # Prefer the final observation's reward; fall back to the env's running
    # total (private attribute) when the observation carries none.
    score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
    broadcast_fn(ep_id, {
        "type": "episode_done",
        "episode_id": ep_id,
        "status": status,
        "score": float(score),
        "final_metrics": obs.metrics,
    })

    return {"episode_id": ep_id, "score": float(score), "status": status}
168
+
169
+
170
+ def _obs_dict(obs) -> dict:
171
+ try:
172
+ return obs.model_dump()
173
+ except AttributeError:
174
+ return {
175
+ "paper_state": getattr(obs, "paper_state", {}),
176
+ "metrics": getattr(obs, "metrics", {}),
177
+ "fold_history": getattr(obs, "fold_history", []),
178
+ "done": getattr(obs, "done", False),
179
+ "reward": getattr(obs, "reward", None),
180
+ }
181
+
182
+
183
+ # ── Main ──────────────────────────────────────────────────────────────────────
184
+
185
async def run_demo() -> None:
    """Run NUM_EPISODES concurrent LLM-driven episodes and report scores.

    Requires ANTHROPIC_API_KEY in the environment. Episodes stream live to the
    broadcast server; each one drives its own Claude conversation.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        raise RuntimeError("Set ANTHROPIC_API_KEY environment variable")

    client = anthropic.Anthropic(api_key=api_key)
    task = get_task_by_name(TASK_NAME)

    await asyncio.sleep(1.5)  # wait for server startup

    print(f"\n[llm-demo] Model: {MODEL}")
    print(f"[llm-demo] Task: {TASK_NAME} — {task['description']}")
    print(f"[llm-demo] Open: http://localhost:9001/viewer/training.html\n")

    # Announce a single batch (batch id 1) of NUM_EPISODES to viewers.
    await broadcast.start_batch(1, NUM_EPISODES)

    ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
    # One strategy per episode so each keeps an independent chat history.
    strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]

    # Run all episodes concurrently (each makes its own Claude API calls)
    results = await asyncio.gather(*[
        asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
        for fn, ep_id in zip(strategies, ep_ids)
    ])

    scores = [r["score"] for r in results]
    best_idx = max(range(len(scores)), key=lambda i: scores[i])

    await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])

    print("\n[llm-demo] Results:")
    for i, result in enumerate(results):
        print(f"  ep_{i:02d}  score={result['score']:+.2f}  status={result['status']}")
    print(f"\n[llm-demo] Best: ep_{best_idx:02d} (score={scores[best_idx]:+.2f})")
    print("\n[llm-demo] Press Ctrl+C to stop.\n")
220
+
221
+
222
async def _main() -> None:
    """Serve the training app and drive the LLM demo in the same event loop."""
    server = uvicorn.Server(
        uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
    )
    await asyncio.gather(server.serve(), run_demo())
226
+
227
+
228
if __name__ == "__main__":
    # Run server + demo until interrupted; Ctrl+C exits with a clean message
    # instead of a KeyboardInterrupt traceback.
    try:
        asyncio.run(_main())
    except KeyboardInterrupt:
        print("\n[llm-demo] Stopped.")
training/runner.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TrainingRunner — parallel episode executor for GRPO training.
3
+
4
+ Each episode runs in a ThreadPoolExecutor thread.
5
+ After every env.step(), observations are pushed to the broadcast server (fire-and-forget).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import uuid
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from typing import Any, Callable, Optional
12
+
13
+ from server.models import OrigamiAction
14
+ from server.origami_environment import OrigamiEnvironment
15
+
16
+
17
# Live-stream callback signature: (episode_id, json_payload) -> None.
BroadcastFn = Callable[[str, dict], None]
18
+
19
+
20
def run_episode(
    strategy_fn: Callable[[dict], dict],
    task_name: str,
    ep_id: Optional[str] = None,
    broadcast_fn: Optional[BroadcastFn] = None,
    max_steps: Optional[int] = None,
) -> dict:
    """Run a single origami episode with a given strategy function.

    Args:
        strategy_fn: Callable that receives paper_state dict and returns a fold dict:
            {"type": "valley"|"mountain"|"pleat"|"crimp"|"stop",
             "line": {"start": [x, y], "end": [x, y]},
             "angle": 180.0}
        task_name: Name of the task (from server/tasks.py)
        ep_id: Episode identifier for broadcast; auto-generated if None
        broadcast_fn: Optional callback(ep_id, data) for live streaming
        max_steps: Override task's max_folds if provided

    Returns:
        dict with keys: episode_id, score, final_metrics, fold_history, status
    """
    ep_id = ep_id or str(uuid.uuid4())[:8]
    env = OrigamiEnvironment()

    obs = env.reset(task_name=task_name)

    if broadcast_fn:
        broadcast_fn(ep_id, {
            "type": "episode_update",
            "episode_id": ep_id,
            "task_name": task_name,
            "step": 0,
            "observation": _obs_to_dict(obs),
        })

    # BUG FIX: the original expression
    #     max_steps or env._task.get("max_folds", 20) if env._task else 20
    # parsed as `(max_steps or ...) if env._task else 20` (conditional binds
    # last), silently discarding a caller-supplied max_steps whenever
    # env._task was unset — contradicting the docstring's "override" contract.
    if max_steps:
        step_limit = max_steps
    elif env._task:
        step_limit = env._task.get("max_folds", 20)
    else:
        step_limit = 20
    status = "done"

    for step_idx in range(step_limit):
        if obs.done:
            break

        # Strategy generates a fold dict
        try:
            fold_dict = strategy_fn(obs.paper_state)
        except Exception as exc:
            # A crashing strategy ends the episode as "error".
            status = "error"
            if broadcast_fn:
                broadcast_fn(ep_id, {
                    "type": "episode_done",
                    "episode_id": ep_id,
                    "status": "error",
                    "score": obs.reward or 0.0,
                    "final_metrics": obs.metrics,
                    "error": str(exc),
                })
            # NOTE(review): falling through also emits the generic
            # episode_done below, so viewers see two "done" events on the
            # error path — confirm the frontend dedupes by episode_id.
            break

        # Missing keys fall back to a flat valley fold across the middle.
        fold_type = fold_dict.get("type", "valley")
        fold_line = fold_dict.get("line", {"start": [0, 0.5], "end": [1, 0.5]})
        fold_angle = float(fold_dict.get("angle", 180.0))

        # NOTE(review): unlike the demo runners, a "stop" fold type is
        # forwarded to the env here — confirm OrigamiEnvironment treats it
        # as terminal rather than an invalid fold.
        action = OrigamiAction(
            fold_type=fold_type,
            fold_line=fold_line,
            fold_angle=fold_angle,
        )
        obs = env.step(action)

        if broadcast_fn:
            broadcast_fn(ep_id, {
                "type": "episode_update",
                "episode_id": ep_id,
                "task_name": task_name,
                "step": step_idx + 1,
                "observation": _obs_to_dict(obs),
            })

        if obs.done:
            break
    else:
        # Loop exhausted step_limit without finishing.
        status = "timeout"

    # Prefer the last observation's reward; fall back to the env's running total.
    score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)

    if broadcast_fn:
        broadcast_fn(ep_id, {
            "type": "episode_done",
            "episode_id": ep_id,
            "status": status,
            "score": float(score),
            "final_metrics": obs.metrics,
        })

    return {
        "episode_id": ep_id,
        "score": float(score),
        "final_metrics": obs.metrics,
        "fold_history": obs.fold_history,
        "status": status,
    }
122
+
123
+
124
def run_batch(
    strategy_fns: list[Callable[[dict], dict]],
    task_name: str,
    broadcast_fn: Optional[BroadcastFn] = None,
    batch_id: Optional[int] = None,
    max_workers: int = 8,
) -> list[dict]:
    """Run G episodes in parallel with a ThreadPoolExecutor.

    Args:
        strategy_fns: List of G strategy callables (one per completion)
        task_name: Task to use for all episodes
        broadcast_fn: Optional broadcast callback, called after each step
        batch_id: Batch identifier for broadcast
        max_workers: Max parallel threads (bounded by G)

    Returns:
        List of episode result dicts, in same order as strategy_fns
    """
    # BUG FIX: an empty batch previously reached
    # ThreadPoolExecutor(max_workers=0), which raises ValueError.
    if not strategy_fns:
        return []

    n = len(strategy_fns)
    ep_ids = [f"ep_{(batch_id or 0):04d}_{i:02d}" for i in range(n)]
    workers = min(max_workers, n)

    # One independent placeholder per slot. (The original used `[{}] * n`,
    # which aliases a single shared dict across every slot — harmless only
    # because slots are reassigned below, but a classic mutation hazard.)
    results: list[dict] = [{} for _ in range(n)]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            pool.submit(
                run_episode,
                fn,
                task_name,
                ep_ids[i],
                broadcast_fn,
            ): i
            for i, fn in enumerate(strategy_fns)
        }

        for future in as_completed(futures):
            idx = futures[future]
            try:
                results[idx] = future.result()
            except Exception as exc:
                # Record the failure in-slot so ordering is preserved.
                results[idx] = {
                    "episode_id": ep_ids[idx],
                    "score": 0.0,
                    "final_metrics": {},
                    "fold_history": [],
                    "status": "error",
                    "error": str(exc),
                }

    return results
176
+
177
+
178
+ def _obs_to_dict(obs) -> dict:
179
+ """Convert OrigamiObservation to a JSON-serializable dict."""
180
+ try:
181
+ return obs.model_dump()
182
+ except AttributeError:
183
+ return {
184
+ "task": obs.task if hasattr(obs, "task") else {},
185
+ "paper_state": obs.paper_state if hasattr(obs, "paper_state") else {},
186
+ "metrics": obs.metrics if hasattr(obs, "metrics") else {},
187
+ "fold_history": obs.fold_history if hasattr(obs, "fold_history") else [],
188
+ "done": obs.done if hasattr(obs, "done") else False,
189
+ "reward": obs.reward if hasattr(obs, "reward") else None,
190
+ "error": obs.error if hasattr(obs, "error") else None,
191
+ }