sissississi ianalin123 commited on
Commit
1e49495
·
1 Parent(s): 2c8a058

new-environment (#5)

Browse files

- docs/handoff (e4d16d926a54cdfbad613d801124bef6cd53f17e)
- plans/ (39c6d2318faab2691566c378dca0bd9749304b79)
- feat: implement origami RL environment (Phase 1) (c44bdad79837ff4636d2e8902eea99065fd60681)
- feat: React observability dashboard + FastAPI server + matplotlib renderer (25db0fc5dc0c74bd090c28bd0153c91148342daa)
- feat: Python 3D origami mass-spring simulator (Ghassaei 2018) (dc79e2a71ff693ac301e0bb5553606450cfe3cae)
- Add 3D fold preview modes (3744ef301f1559b3b51ddf7aac6dc32ad689fdff)
- research (608285dd42d04ff717578322c88f527fed3fd317)
- Merge branch 'main' of https://huggingface.co/spaces/openenv-community/optigami (d2552c731aa2683cc7db6ffb6d82c2539b6d2398)
- chore: resolve merge conflicts — keep HF deployment fixes (f8d2bab36c7f83aaf21d2665fa5d5e6e24ada2bc)
- feat: update engine modules, remove research 2 docs (ca61c8d786e330b35c89a1b9bc3d5d518843cdfb)
- fix: rename server.py to server_legacy.py, add server/ package (0bcd0b11a6083cba1e78a82522594f478f937559)
- refactor(server): migrate demo routes to server/ task+env API (c46fef811c5bc6628ceebf2566b0863f32234c92)
- fix(canvas): adapt CreaseCanvas to engine FOLD-format paper_state (8da8915849405917b4292e55f2272200c485b612)
- feat(3d): render engine paper_state directly with strain heatmap (f6709d8692a6b631528abbaf7ecda38d5aac5123)
- refactor(metrics): replace reward decomposition with engine metrics (d091b7773f66c2a8211c165d1dbb49f852f3976f)
- refactor(app): update App and StepFeed for new fold/metrics schema (8cc1585c4c26acac14d77c79e30a662605867e38)
- Merge branch 'main' of https://huggingface.co/spaces/openenv-community/optigami (5eca717229bb839f3a4caa2f6ad8a23c93d28c4d)
- feat(server): add training broadcast server and Colab training FastAPI app (6cf63a9c454cabc3b4a3076aeed29015a21fd225)
- feat(training): add parallel episode runner and demo scripts (a884e864be73e6073366a57de1d585458a1f7688)
- feat(viewer): add training grid viewer HTML (c4160923af5fb634f948070712309b407444df6c)
- feat(frontend): replay mode, camera angle fix, endpoint alignment (9221fb1f5f545fdeead600d2df6260d049c1e291)


Co-authored-by: Iana Lin <ianalin123@users.noreply.huggingface.co>

engine/fold_engine.py CHANGED
@@ -151,6 +151,8 @@ def apply_fold(
151
  elif face_sides[i] == "fixed" and face_sides[j] == "rotated":
152
  new_paper.face_orders.append((j, i, 1))
153
 
 
 
154
  return new_paper, None
155
 
156
 
@@ -205,3 +207,43 @@ def execute_fold_strategy(
205
  applied.append(fold)
206
 
207
  return current, applied, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  elif face_sides[i] == "fixed" and face_sides[j] == "rotated":
152
  new_paper.face_orders.append((j, i, 1))
153
 
154
+ new_paper.fold_count += 1
155
+
156
  return new_paper, None
157
 
158
 
 
207
  applied.append(fold)
208
 
209
  return current, applied, None
210
+
211
+
212
+ def apply_pleat(
213
+ paper: Paper,
214
+ line1: dict,
215
+ line2: dict,
216
+ angle: float = 180.0,
217
+ ) -> tuple[Paper, str | None]:
218
+ """Pleat fold: valley at line1, mountain at line2 (two parallel folds).
219
+
220
+ Both line dicts have the form: {"start": [x, y], "end": [x, y]}
221
+ Returns (new_paper, error_or_None).
222
+ """
223
+ paper, err = apply_fold(paper, {"type": "valley", "line": line1, "angle": angle})
224
+ if err:
225
+ return paper, f"Pleat valley fold failed: {err}"
226
+ paper, err = apply_fold(paper, {"type": "mountain", "line": line2, "angle": angle})
227
+ if err:
228
+ return paper, f"Pleat mountain fold failed: {err}"
229
+ return paper, None
230
+
231
+
232
+ def apply_crimp(
233
+ paper: Paper,
234
+ line1: dict,
235
+ line2: dict,
236
+ angle: float = 180.0,
237
+ ) -> tuple[Paper, str | None]:
238
+ """Crimp fold: mountain at line1, valley at line2 (reverse of pleat).
239
+
240
+ Both line dicts have the form: {"start": [x, y], "end": [x, y]}
241
+ Returns (new_paper, error_or_None).
242
+ """
243
+ paper, err = apply_fold(paper, {"type": "mountain", "line": line1, "angle": angle})
244
+ if err:
245
+ return paper, f"Crimp mountain fold failed: {err}"
246
+ paper, err = apply_fold(paper, {"type": "valley", "line": line2, "angle": angle})
247
+ if err:
248
+ return paper, f"Crimp valley fold failed: {err}"
249
+ return paper, None
engine/metrics.py CHANGED
@@ -102,3 +102,130 @@ def compute_metrics(paper: Paper, original_paper: Paper | None = None) -> dict:
102
  "num_faces": len(paper.faces),
103
  "num_layers": paper.num_layers,
104
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  "num_faces": len(paper.faces),
103
  "num_layers": paper.num_layers,
104
  }
105
+
106
+
107
+ def compute_all_metrics(paper, task: dict, validation: dict) -> dict:
108
+ """Compute every metric and return a flat dict.
109
+
110
+ Called after physics + validation. Combines validity, compactness,
111
+ structural, efficiency, and deployability metrics.
112
+
113
+ Parameters
114
+ ----------
115
+ paper : Paper
116
+ Current paper state (after simulate()).
117
+ task : dict
118
+ Task definition with keys: width, height, target_ratio, target_box, must_deploy.
119
+ validation : dict
120
+ Output of validate_state(paper).
121
+ """
122
+ import numpy as np
123
+
124
+ bb = paper.bounding_box # (3,) array
125
+ original_area = paper.original_area if paper.original_area > 0 else (paper.material.thickness_mm / 1000.0)
126
+ t = paper.material.thickness_mm / 1000.0
127
+ original_bbox_vol = original_area * t
128
+ folded_bbox_vol = float(bb[0] * bb[1] * bb[2]) if bb[2] > 0 else float(bb[0] * bb[1] * t)
129
+
130
+ # ── Folded area (XY footprint) ────────────────────────────────
131
+ if len(paper.vertices) >= 3:
132
+ try:
133
+ from scipy.spatial import ConvexHull
134
+ hull = ConvexHull(paper.vertices[:, :2])
135
+ folded_area = float(hull.volume)
136
+ except Exception:
137
+ ptp = np.ptp(paper.vertices[:, :2], axis=0)
138
+ folded_area = float(ptp[0] * ptp[1])
139
+ else:
140
+ folded_area = original_area
141
+
142
+ deployment_ratio = folded_area / original_area if original_area > 0 else 1.0
143
+ compactness = 1.0 - deployment_ratio
144
+ volume_compaction = folded_bbox_vol / original_bbox_vol if original_bbox_vol > 0 else 1.0
145
+ material_volume = original_area * t
146
+ packing_efficiency = material_volume / folded_bbox_vol if folded_bbox_vol > 0 else 0.0
147
+
148
+ # ── Target box check ─────────────────────────────────────────
149
+ target_box = task.get("target_box")
150
+ fits_target_box = False
151
+ if target_box and len(target_box) == 3:
152
+ fits_target_box = bool(
153
+ bb[0] <= target_box[0] + 1e-6 and
154
+ bb[1] <= target_box[1] + 1e-6 and
155
+ bb[2] <= target_box[2] + 1e-6
156
+ )
157
+
158
+ # ── Strain ───────────────────────────────────────────────────
159
+ strain = paper.strain_per_vertex
160
+ max_strain = float(np.max(strain)) if len(strain) > 0 else 0.0
161
+ mean_strain = float(np.mean(strain)) if len(strain) > 0 else 0.0
162
+
163
+ # ── Energy ───────────────────────────────────────────────────
164
+ energy = paper.energy
165
+
166
+ # ── Efficiency ───────────────────────────────────────────────
167
+ fold_count = paper.fold_count
168
+
169
+ # Crease complexity: entropy of M/V assignment distribution
170
+ mv_assignments = [a for a in paper.assignments if a in ("M", "V")]
171
+ if mv_assignments:
172
+ total = len(mv_assignments)
173
+ m_count = mv_assignments.count("M")
174
+ v_count = mv_assignments.count("V")
175
+ p_m = m_count / total if total > 0 else 0
176
+ p_v = v_count / total if total > 0 else 0
177
+ crease_complexity = 0.0
178
+ if p_m > 0:
179
+ crease_complexity -= p_m * np.log2(p_m)
180
+ if p_v > 0:
181
+ crease_complexity -= p_v * np.log2(p_v)
182
+ else:
183
+ crease_complexity = 0.0
184
+
185
+ folding_efficiency = compactness / max(fold_count, 1)
186
+
187
+ # ── Deployability ─────────────────────────────────────────────
188
+ must_deploy = task.get("must_deploy", False)
189
+ # Simple deployability heuristic: if valid and compactness > 0, assume deployable
190
+ is_deployable = bool(validation.get("is_valid", False) and compactness > 0.01) if must_deploy else None
191
+ # Deployment force estimate from total energy gradient (rough)
192
+ deployment_force_estimate = float(energy.get("fold", 0.0)) / max(paper.original_area, 1e-6)
193
+
194
+ return {
195
+ # Validity (from validation dict)
196
+ "is_valid": validation.get("is_valid", False),
197
+ "kawasaki_violations": validation.get("kawasaki_violations", 0),
198
+ "kawasaki_total_error": validation.get("kawasaki_total_error", 0.0),
199
+ "maekawa_violations": validation.get("maekawa_violations", 0),
200
+ "self_intersections": validation.get("self_intersections", 0),
201
+ "strain_exceeded": validation.get("strain_exceeded", False),
202
+
203
+ # Compactness
204
+ "deployment_ratio": float(deployment_ratio),
205
+ "compactness": float(compactness),
206
+ "volume_compaction": float(volume_compaction),
207
+ "packing_efficiency": float(packing_efficiency),
208
+ "fits_target_box": fits_target_box,
209
+ "bounding_box": bb.tolist(),
210
+
211
+ # Structural
212
+ "max_strain": max_strain,
213
+ "mean_strain": mean_strain,
214
+ "total_energy": float(energy.get("total", 0.0)),
215
+ "energy_bar": float(energy.get("bar", 0.0)),
216
+ "energy_facet": float(energy.get("facet", 0.0)),
217
+ "energy_fold": float(energy.get("fold", 0.0)),
218
+
219
+ # Efficiency
220
+ "fold_count": fold_count,
221
+ "folding_efficiency": float(folding_efficiency),
222
+ "crease_complexity": float(crease_complexity),
223
+
224
+ # Deployability
225
+ "is_deployable": is_deployable,
226
+ "deployment_force_estimate": float(deployment_force_estimate),
227
+
228
+ # Shape similarity placeholders
229
+ "chamfer_distance": None,
230
+ "hausdorff_distance": None,
231
+ }
engine/paper.py CHANGED
@@ -89,6 +89,10 @@ class Paper:
89
  material: Material = field(default_factory=lambda: get_material("paper"))
90
  rest_lengths: np.ndarray = field(default_factory=lambda: np.empty(0))
91
  original_area: float = 0.0
 
 
 
 
92
 
93
  # ── constructors ────────────────────────────────────────────────
94
 
@@ -125,7 +129,7 @@ class Paper:
125
  dtype=np.float64,
126
  )
127
 
128
- return Paper(
129
  vertices=verts,
130
  edges=edges,
131
  faces=faces,
@@ -135,6 +139,8 @@ class Paper:
135
  rest_lengths=rest_lengths,
136
  original_area=width * height,
137
  )
 
 
138
 
139
  # ── dict / prompt serialization (matches mock_env.PaperState.to_dict) ──
140
 
@@ -165,6 +171,33 @@ class Paper:
165
  },
166
  }
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # ── FOLD format serialization ───────────────────────────────────
169
 
170
  def to_fold_json(self) -> str:
@@ -485,4 +518,8 @@ class Paper:
485
  ),
486
  rest_lengths=self.rest_lengths.copy(),
487
  original_area=self.original_area,
 
 
 
 
488
  )
 
89
  material: Material = field(default_factory=lambda: get_material("paper"))
90
  rest_lengths: np.ndarray = field(default_factory=lambda: np.empty(0))
91
  original_area: float = 0.0
92
+ rest_positions: np.ndarray = field(default_factory=lambda: np.empty((0, 3)))
93
+ strain_per_vertex: np.ndarray = field(default_factory=lambda: np.empty(0))
94
+ energy: dict = field(default_factory=lambda: {"total": 0.0, "bar": 0.0, "facet": 0.0, "fold": 0.0})
95
+ fold_count: int = 0
96
 
97
  # ── constructors ────────────────────────────────────────────────
98
 
 
129
  dtype=np.float64,
130
  )
131
 
132
+ paper = Paper(
133
  vertices=verts,
134
  edges=edges,
135
  faces=faces,
 
139
  rest_lengths=rest_lengths,
140
  original_area=width * height,
141
  )
142
+ paper.rest_positions = verts.copy()
143
+ return paper
144
 
145
  # ── dict / prompt serialization (matches mock_env.PaperState.to_dict) ──
146
 
 
171
  },
172
  }
173
 
174
+ def to_observation_dict(self) -> dict:
175
+ bb = self.bounding_box
176
+ return {
177
+ "vertices_coords": self.vertices.tolist(),
178
+ "edges_vertices": self.edges.tolist(),
179
+ "faces_vertices": self.faces,
180
+ "edges_assignment": list(self.assignments),
181
+ "edges_foldAngle": self.fold_angles.tolist(),
182
+ "num_vertices": len(self.vertices),
183
+ "num_edges": len(self.edges),
184
+ "num_faces": len(self.faces),
185
+ "bounding_box": bb.tolist(),
186
+ "num_layers": self.num_layers,
187
+ "material": {
188
+ "name": self.material.name,
189
+ "thickness_mm": self.material.thickness_mm,
190
+ "youngs_modulus_gpa": self.material.youngs_modulus_gpa,
191
+ "max_strain": self.material.max_strain,
192
+ "poisson_ratio": self.material.poissons_ratio,
193
+ },
194
+ "strain_per_vertex": self.strain_per_vertex.tolist(),
195
+ "energy": dict(self.energy),
196
+ "fold_count": self.fold_count,
197
+ "width": float(self.original_area ** 0.5) if self.original_area > 0 else 1.0,
198
+ "height": float(self.original_area ** 0.5) if self.original_area > 0 else 1.0,
199
+ }
200
+
201
  # ── FOLD format serialization ───────────────────────────────────
202
 
203
  def to_fold_json(self) -> str:
 
518
  ),
519
  rest_lengths=self.rest_lengths.copy(),
520
  original_area=self.original_area,
521
+ rest_positions=self.rest_positions.copy(),
522
+ strain_per_vertex=self.strain_per_vertex.copy(),
523
+ energy=dict(self.energy),
524
+ fold_count=self.fold_count,
525
  )
engine/physics.py CHANGED
@@ -255,3 +255,263 @@ def _face_normal(verts: np.ndarray, face: list[int]) -> np.ndarray | None:
255
  if norm < 1e-15:
256
  return None
257
  return normal / norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  if norm < 1e-15:
256
  return None
257
  return normal / norm
258
+
259
+
260
+ # ────────────────────────────────────────────────────────────────────
261
+ # Topology precomputation
262
+ # ────────────────────────────────────────────────────────────────────
263
+
264
+ def build_beam_list(paper: Paper) -> list[tuple[int, int, float, float]]:
265
+ """Build list of (node_a, node_b, rest_len, k_axial) for every edge.
266
+
267
+ Uses normalized stiffness values (arch doc constants) scaled by material
268
+ Young's modulus ratio — keeps the Verlet integrator stable at unit scale.
269
+ """
270
+ # Normalized stiffness constants (arch doc values)
271
+ K_AXIAL_BASE = 70.0
272
+ # Scale by material: paper (3 GPa) = 1.0 baseline
273
+ mat = paper.material
274
+ E_ratio = mat.youngs_modulus_gpa / 3.0
275
+ k_axial = K_AXIAL_BASE * E_ratio
276
+
277
+ beams = []
278
+ for ei, (v1, v2) in enumerate(paper.edges):
279
+ L0 = paper.rest_lengths[ei]
280
+ beams.append((int(v1), int(v2), float(L0), float(k_axial)))
281
+ return beams
282
+
283
+
284
+ def build_crease_list(paper: Paper) -> list[tuple[int, int, int, int, float, float, str]]:
285
+ """Build list of (n1, n2, n3, n4, target_angle_rad, k, type) for each crease hinge.
286
+
287
+ Each hinge is defined by 4 nodes: n1-n2 is the hinge edge, n3 and n4 are
288
+ the wing-tip nodes of the two adjacent faces.
289
+ type is 'fold' (M/V crease) or 'facet' (interior flat edge).
290
+ """
291
+ verts = paper.vertices
292
+
293
+ # Build edge → face adjacency
294
+ edge_faces: dict[int, list[int]] = {}
295
+ for fi, face in enumerate(paper.faces):
296
+ n = len(face)
297
+ for k in range(n):
298
+ va, vb = face[k], face[(k + 1) % n]
299
+ for ei, e in enumerate(paper.edges):
300
+ if (e[0] == va and e[1] == vb) or (e[0] == vb and e[1] == va):
301
+ edge_faces.setdefault(ei, []).append(fi)
302
+ break
303
+
304
+ creases = []
305
+ for ei, adj in edge_faces.items():
306
+ if len(adj) < 2:
307
+ continue
308
+ f1, f2 = adj[0], adj[1]
309
+ face1, face2 = paper.faces[f1], paper.faces[f2]
310
+ n1, n2 = int(paper.edges[ei][0]), int(paper.edges[ei][1])
311
+
312
+ # Find wing-tip nodes (in each face, the vertex NOT on the shared edge)
313
+ wing1 = [v for v in face1 if v != n1 and v != n2]
314
+ wing2 = [v for v in face2 if v != n1 and v != n2]
315
+ if not wing1 or not wing2:
316
+ continue
317
+ n3, n4 = int(wing1[0]), int(wing2[0])
318
+
319
+ # Normalized stiffness constants (arch doc values), scaled by material
320
+ E_ratio = paper.material.youngs_modulus_gpa / 3.0
321
+ K_FACET = 0.2 * E_ratio
322
+ K_FOLD = 0.7 * E_ratio
323
+
324
+ asgn = paper.assignments[ei]
325
+ if asgn in ("M", "V"):
326
+ target = float(np.radians(paper.fold_angles[ei]))
327
+ k = K_FOLD
328
+ ctype = "fold"
329
+ else:
330
+ target = float(np.pi)
331
+ k = K_FACET
332
+ ctype = "facet"
333
+
334
+ creases.append((n1, n2, n3, n4, target, k, ctype))
335
+ return creases
336
+
337
+
338
+ def _torque_to_forces(
339
+ p1: np.ndarray, p2: np.ndarray,
340
+ p3: np.ndarray, p4: np.ndarray,
341
+ torque: float,
342
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
343
+ """Convert a dihedral torque into forces on the 4 hinge nodes.
344
+
345
+ p1-p2 is the hinge edge. p3 and p4 are wing tips.
346
+ Returns (f1, f2, f3, f4) as (3,) arrays.
347
+ """
348
+ e = p2 - p1
349
+ e_len = np.linalg.norm(e)
350
+ if e_len < 1e-12:
351
+ zero = np.zeros(3)
352
+ return zero, zero, zero, zero
353
+
354
+ e_hat = e / e_len
355
+
356
+ # Perpendicular components of wing vectors relative to hinge
357
+ d3 = p3 - p1
358
+ d4 = p4 - p1
359
+ d3_perp = d3 - np.dot(d3, e_hat) * e_hat
360
+ d4_perp = d4 - np.dot(d4, e_hat) * e_hat
361
+
362
+ len3 = np.linalg.norm(d3_perp)
363
+ len4 = np.linalg.norm(d4_perp)
364
+
365
+ if len3 < 1e-12 or len4 < 1e-12:
366
+ zero = np.zeros(3)
367
+ return zero, zero, zero, zero
368
+
369
+ # Force on wing tips proportional to torque / lever arm
370
+ f3 = torque / (len3 * e_len) * np.cross(e_hat, d3_perp / len3)
371
+ f4 = -torque / (len4 * e_len) * np.cross(e_hat, d4_perp / len4)
372
+
373
+ # Reaction forces distributed to hinge nodes
374
+ f1 = -(f3 + f4) * 0.5
375
+ f2 = -(f3 + f4) * 0.5
376
+
377
+ return f1, f2, f3, f4
378
+
379
+
380
+ # ────────────────────────────────────────────────────────────────────
381
+ # Verlet solver
382
+ # ────────────────────────────────────────────────────────────────────
383
+
384
+ def simulate(
385
+ paper: Paper,
386
+ fold_percent: float = 1.0,
387
+ n_steps: int = 500,
388
+ dt: float = 0.005,
389
+ damping: float = 0.15,
390
+ ) -> Paper:
391
+ """Run bar-and-hinge Verlet integration to relax the mesh.
392
+
393
+ Updates paper.vertices, paper.strain_per_vertex, and paper.energy in-place.
394
+ Returns the mutated paper for chaining.
395
+
396
+ Parameters
397
+ ----------
398
+ paper : Paper
399
+ Paper state after a fold has been applied (vertices already rotated).
400
+ fold_percent : float
401
+ How far along the fold to drive (0=flat, 1=full target angle).
402
+ n_steps : int
403
+ Maximum integration steps.
404
+ dt : float
405
+ Time step. Keep small (0.005) for stability with stiff materials.
406
+ damping : float
407
+ Velocity damping coefficient (0=undamped, 1=fully damped).
408
+ """
409
+ if len(paper.vertices) == 0:
410
+ return paper
411
+
412
+ beams = build_beam_list(paper)
413
+ creases = build_crease_list(paper)
414
+
415
+ pos = paper.vertices.copy() # (N, 3) current positions
416
+ last_pos = pos.copy() # (N, 3) previous positions (Verlet)
417
+
418
+ max_force_cap = 1e6 # prevent runaway forces
419
+
420
+ for _ in range(n_steps):
421
+ forces = np.zeros_like(pos)
422
+
423
+ # ── Beam (axial spring) forces ───────────────────────────────
424
+ for (a, b, L0, k) in beams:
425
+ delta = pos[b] - pos[a]
426
+ L = np.linalg.norm(delta)
427
+ if L < 1e-12:
428
+ continue
429
+ strain = (L - L0) / L0
430
+ F_mag = k * strain
431
+ F_vec = F_mag * (delta / L)
432
+ # Clamp to prevent instability
433
+ F_vec = np.clip(F_vec, -max_force_cap, max_force_cap)
434
+ forces[a] += F_vec
435
+ forces[b] -= F_vec
436
+
437
+ # ── Crease (dihedral spring) forces ─────────────────────────
438
+ for (n1, n2, n3, n4, target, k, ctype) in creases:
439
+ actual_target = target * fold_percent if ctype == "fold" else target
440
+ try:
441
+ theta = _compute_dihedral_rad(pos[n1], pos[n2], pos[n3], pos[n4])
442
+ except Exception:
443
+ continue
444
+ delta_theta = theta - actual_target
445
+ edge_len = np.linalg.norm(pos[n2] - pos[n1])
446
+ torque = k * edge_len * delta_theta
447
+ torque = float(np.clip(torque, -max_force_cap, max_force_cap))
448
+
449
+ f1, f2, f3, f4 = _torque_to_forces(
450
+ pos[n1], pos[n2], pos[n3], pos[n4], torque
451
+ )
452
+ forces[n1] += np.clip(f1, -max_force_cap, max_force_cap)
453
+ forces[n2] += np.clip(f2, -max_force_cap, max_force_cap)
454
+ forces[n3] += np.clip(f3, -max_force_cap, max_force_cap)
455
+ forces[n4] += np.clip(f4, -max_force_cap, max_force_cap)
456
+
457
+ # ── Verlet integration ───────────────────────────────────────
458
+ new_pos = pos + (1.0 - damping) * (pos - last_pos) + forces * (dt * dt)
459
+
460
+ # NaN guard
461
+ if np.any(np.isnan(new_pos)):
462
+ break
463
+
464
+ last_pos = pos
465
+ pos = new_pos
466
+
467
+ # ── Convergence check ────────────────────────────────────────
468
+ kinetic = np.sum((pos - last_pos) ** 2)
469
+ if kinetic < 1e-12:
470
+ break
471
+
472
+ # ── Write results back to paper ──────────────────────────────────
473
+ paper.vertices = pos
474
+ paper.strain_per_vertex = compute_strain(paper)
475
+ paper.energy = {
476
+ "total": compute_total_energy(paper),
477
+ "bar": compute_bar_energy(paper),
478
+ "facet": compute_facet_energy(paper),
479
+ "fold": compute_fold_energy(paper),
480
+ }
481
+
482
+ return paper
483
+
484
+
485
+ def _compute_dihedral_rad(
486
+ p1: np.ndarray, p2: np.ndarray,
487
+ p3: np.ndarray, p4: np.ndarray,
488
+ ) -> float:
489
+ """Dihedral angle in radians between planes (p1,p2,p3) and (p1,p2,p4).
490
+
491
+ p1-p2 is the hinge edge. p3 and p4 are the wing tips.
492
+ Returns angle in [0, 2*pi).
493
+ """
494
+ e = p2 - p1
495
+ e_norm = np.linalg.norm(e)
496
+ if e_norm < 1e-12:
497
+ return float(np.pi)
498
+ e_hat = e / e_norm
499
+
500
+ n1 = np.cross(p3 - p1, e)
501
+ n2 = np.cross(e, p4 - p1)
502
+ len1 = np.linalg.norm(n1)
503
+ len2 = np.linalg.norm(n2)
504
+ if len1 < 1e-12 or len2 < 1e-12:
505
+ return float(np.pi)
506
+
507
+ n1 = n1 / len1
508
+ n2 = n2 / len2
509
+
510
+ cos_a = float(np.clip(np.dot(n1, n2), -1.0, 1.0))
511
+ angle = np.arccos(cos_a)
512
+
513
+ cross = np.cross(n1, n2)
514
+ if np.dot(cross, e_hat) < 0:
515
+ angle = 2.0 * np.pi - angle
516
+
517
+ return float(angle)
engine/validation.py CHANGED
@@ -254,3 +254,25 @@ def validate_paper(paper: Paper) -> ValidationResult:
254
  self_intersection_count=si_count,
255
  is_valid=k_valid and m_valid and si_valid,
256
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  self_intersection_count=si_count,
255
  is_valid=k_valid and m_valid and si_valid,
256
  )
257
+
258
+
259
+ def validate_state(paper: Paper) -> dict:
260
+ """Run all validation checks and return a flat dict.
261
+
262
+ This is the interface used by OrigamiEnvironment. It calls the
263
+ existing validation functions and returns a dict with all fields
264
+ the environment and metrics system need.
265
+ """
266
+ result = validate_paper(paper)
267
+ strain_exceeded = bool(
268
+ len(paper.strain_per_vertex) > 0
269
+ and float(paper.strain_per_vertex.max()) > paper.material.max_strain
270
+ )
271
+ return {
272
+ "is_valid": result.is_valid and not strain_exceeded,
273
+ "kawasaki_violations": int(not result.kawasaki_valid),
274
+ "kawasaki_total_error": float(result.kawasaki_violation),
275
+ "maekawa_violations": int(not result.maekawa_valid),
276
+ "self_intersections": result.self_intersection_count,
277
+ "strain_exceeded": strain_exceeded,
278
+ }
openenv_server/app.py CHANGED
@@ -19,123 +19,116 @@ app = create_app(
19
 
20
 
21
  # ---------------------------------------------------------------------------
22
- # Demo routes required by the React frontend.
23
- # These must be registered BEFORE the StaticFiles catch-all mount.
24
  # ---------------------------------------------------------------------------
25
 
26
- DEMO_COMPLETIONS: dict[str, str] = {
27
- "half_horizontal": '<folds>[{"instruction": "Valley fold along horizontal center line", "from": [0, 0.5], "to": [1, 0.5], "assignment": "V"}]</folds>',
28
- "half_vertical": '<folds>[{"instruction": "Mountain fold along vertical center line", "from": [0.5, 0], "to": [0.5, 1], "assignment": "M"}]</folds>',
29
- "diagonal_main": '<folds>[{"instruction": "Valley fold along main diagonal", "from": [0, 0], "to": [1, 1], "assignment": "V"}]</folds>',
30
- "diagonal_anti": '<folds>[{"instruction": "Mountain fold along anti-diagonal", "from": [1, 0], "to": [0, 1], "assignment": "M"}]</folds>',
31
- "thirds_h": '<folds>[{"instruction": "Valley fold at one-third height", "from": [0, 0.333], "to": [1, 0.333], "assignment": "V"}, {"instruction": "Valley fold at two-thirds height", "from": [0, 0.667], "to": [1, 0.667], "assignment": "V"}]</folds>',
32
- "thirds_v": '<folds>[{"instruction": "Mountain fold at one-third width", "from": [0.333, 0], "to": [0.333, 1], "assignment": "M"}, {"instruction": "Mountain fold at two-thirds width", "from": [0.667, 0], "to": [0.667, 1], "assignment": "M"}]</folds>',
33
- "accordion_3h": '<folds>[{"instruction": "Valley fold at quarter height", "from": [0, 0.25], "to": [1, 0.25], "assignment": "V"}, {"instruction": "Mountain fold at half height", "from": [0, 0.5], "to": [1, 0.5], "assignment": "M"}, {"instruction": "Valley fold at three-quarter height", "from": [0, 0.75], "to": [1, 0.75], "assignment": "V"}]</folds>',
34
- "accordion_4h": '<folds>[{"instruction": "Valley fold at 0.2", "from": [0, 0.2], "to": [1, 0.2], "assignment": "V"}, {"instruction": "Mountain fold at 0.4", "from": [0, 0.4], "to": [1, 0.4], "assignment": "M"}, {"instruction": "Valley fold at 0.6", "from": [0, 0.6], "to": [1, 0.6], "assignment": "V"}, {"instruction": "Mountain fold at 0.8", "from": [0, 0.8], "to": [1, 0.8], "assignment": "M"}]</folds>',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  }
36
 
37
 
 
 
 
 
38
  @app.get("/targets", include_in_schema=True)
39
  def get_targets() -> dict:
40
- """Return available target names and metadata for the frontend."""
41
- from env.environment import OrigamiEnvironment
42
 
43
- env = OrigamiEnvironment()
44
  result: dict[str, dict] = {}
45
- for name in env.available_targets():
46
- t = env._targets[name]
47
  result[name] = {
48
  "name": name,
49
- "level": t.get("level", 1),
50
  "description": t.get("description", ""),
51
- "n_creases": sum(1 for a in t["edges_assignment"] if a in ("M", "V")),
 
 
52
  }
53
  return result
54
 
55
 
56
- @app.get("/episode/run", include_in_schema=True)
57
- def run_episode(target: str = "half_horizontal", completion: str = "") -> dict:
58
- """Run a fold-sequence episode and return step-by-step data."""
59
- from env.environment import OrigamiEnvironment
60
- from env.prompts import parse_fold_list, step_level_prompt
61
- from env.rewards import compute_reward
62
-
63
- env = OrigamiEnvironment(mode="step")
64
- obs = env.reset(target_name=target)
65
 
66
- if not completion:
67
- return {"prompt": obs["prompt"], "steps": [], "target": env.target}
68
 
69
- try:
70
- folds = parse_fold_list(completion)
71
- except ValueError as exc:
72
- return {"error": str(exc), "steps": []}
73
 
74
  steps: list[dict] = []
75
- for i, fold in enumerate(folds):
76
- result = env.paper.add_crease(fold["from"], fold["to"], fold["assignment"])
77
- reward = compute_reward(env.paper, result, env.target)
78
-
79
- paper_state = {
80
- "vertices": {str(k): list(v) for k, v in env.paper.graph.vertices.items()},
81
- "edges": [
82
- {
83
- "id": k,
84
- "v1": list(env.paper.graph.vertices[v[0]]),
85
- "v2": list(env.paper.graph.vertices[v[1]]),
86
- "assignment": v[2],
87
- }
88
- for k, v in env.paper.graph.edges.items()
89
- ],
90
- "anchor_points": [list(p) for p in env.paper.anchor_points()],
91
- }
92
 
93
- step_prompt = step_level_prompt(
94
- target=env.target,
95
- paper_state=env.paper,
96
- step=i + 1,
97
- max_steps=env.max_steps,
98
- last_reward=reward,
99
- )
100
 
101
- steps.append(
102
- {
103
- "step": i + 1,
104
- "fold": {
105
- "from_point": fold["from"],
106
- "to_point": fold["to"],
107
- "assignment": fold["assignment"],
108
- "instruction": fold.get("instruction", ""),
109
- },
110
- "paper_state": paper_state,
111
- "anchor_points": [list(p) for p in env.paper.anchor_points()],
112
- "reward": reward,
113
- "done": reward.get("completion", 0) > 0,
114
- "info": env._info(),
115
- "prompt": step_prompt,
116
- }
117
  )
118
 
119
- if reward.get("completion", 0) > 0:
 
 
 
 
 
 
 
 
 
 
120
  break
121
 
 
 
122
  return {
123
- "target_name": target,
124
- "target": env.target,
125
  "steps": steps,
126
- "final_reward": steps[-1]["reward"] if steps else {},
127
  }
128
 
129
 
130
- @app.get("/episode/demo", include_in_schema=True)
131
- def demo_episode(target: str = "half_horizontal") -> dict:
132
- """Return a pre-solved demo episode for the given target."""
133
- completion = DEMO_COMPLETIONS.get(target, DEMO_COMPLETIONS["half_horizontal"])
134
- return run_episode(target=target, completion=completion)
135
-
136
-
137
  # ---------------------------------------------------------------------------
138
- # Static file serving — must come LAST so API routes take priority.
139
  # ---------------------------------------------------------------------------
140
 
141
  _BUILD_DIR = Path(__file__).resolve().parent.parent / "build"
 
19
 
20
 
21
  # ---------------------------------------------------------------------------
22
+ # Demo fold sequences new format: type, line {start, end}, angle
 
23
  # ---------------------------------------------------------------------------
24
 
25
+ DEMO_SEQUENCES: dict[str, list[dict]] = {
26
+ "half_fold": [
27
+ {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
28
+ ],
29
+ "quarter_fold": [
30
+ {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
31
+ {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
32
+ ],
33
+ "letter_fold": [
34
+ {"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
35
+ {"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0},
36
+ ],
37
+ "map_fold": [
38
+ {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
39
+ {"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0},
40
+ ],
41
+ "solar_panel": [
42
+ {"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
43
+ {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
44
+ {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0},
45
+ ],
46
+ "shelter_wall": [
47
+ {"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
48
+ {"type": "valley", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0},
49
+ ],
50
+ "stent": [
51
+ {"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 90.0},
52
+ {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 90.0},
53
+ {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 90.0},
54
+ {"type": "stop", "line": {"start": [0.0, 0.0], "end": [1.0, 1.0]}, "angle": 0.0},
55
+ ],
56
  }
57
 
58
 
59
+ # ---------------------------------------------------------------------------
60
+ # API routes — must be registered BEFORE the StaticFiles catch-all mount
61
+ # ---------------------------------------------------------------------------
62
+
63
  @app.get("/targets", include_in_schema=True)
64
  def get_targets() -> dict:
65
+ """Return available task names and metadata for the frontend."""
66
+ from server.tasks import get_task_by_name, available_task_names
67
 
 
68
  result: dict[str, dict] = {}
69
+ for name in available_task_names():
70
+ t = get_task_by_name(name)
71
  result[name] = {
72
  "name": name,
73
+ "level": t.get("difficulty", 1),
74
  "description": t.get("description", ""),
75
+ "n_creases": t.get("max_folds", 3),
76
+ "difficulty": t.get("difficulty", 1),
77
+ "material": t.get("material", "paper"),
78
  }
79
  return result
80
 
81
 
82
+ @app.get("/episode/demo", include_in_schema=True)
83
+ def demo_episode(target: str = "half_fold") -> dict:
84
+ """Return a pre-solved demo episode for the given task."""
85
+ from server.origami_environment import OrigamiEnvironment
86
+ from server.models import OrigamiAction as NewOrigamiAction
87
+ from server.tasks import get_task_by_name
 
 
 
88
 
89
+ # Fall back to half_fold if target not found
90
+ folds = DEMO_SEQUENCES.get(target, DEMO_SEQUENCES["half_fold"])
91
 
92
+ env = OrigamiEnvironment()
93
+ obs = env.reset(task_name=target)
 
 
94
 
95
  steps: list[dict] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ for i, fold_dict in enumerate(folds):
98
+ if fold_dict.get("type") == "stop":
99
+ break
 
 
 
 
100
 
101
+ action = NewOrigamiAction(
102
+ fold_type=fold_dict["type"],
103
+ fold_line=fold_dict["line"],
104
+ fold_angle=float(fold_dict.get("angle", 180.0)),
 
 
 
 
 
 
 
 
 
 
 
 
105
  )
106
 
107
+ obs = env.step(action)
108
+
109
+ steps.append({
110
+ "step": i + 1,
111
+ "fold": fold_dict,
112
+ "paper_state": obs.paper_state,
113
+ "metrics": obs.metrics,
114
+ "done": obs.done,
115
+ })
116
+
117
+ if obs.done:
118
  break
119
 
120
+ task_def = get_task_by_name(target) if target else {}
121
+
122
  return {
123
+ "task_name": target,
124
+ "task": task_def,
125
  "steps": steps,
126
+ "final_metrics": obs.metrics if steps else {},
127
  }
128
 
129
 
 
 
 
 
 
 
 
130
  # ---------------------------------------------------------------------------
131
+ # Static file serving — must come LAST so API routes take priority
132
  # ---------------------------------------------------------------------------
133
 
134
  _BUILD_DIR = Path(__file__).resolve().parent.parent / "build"
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server/app.py — Training WebSocket server for Colab environment.
3
+
4
+ Provides /ws/training for live streaming of RL training episodes to browsers.
5
+ Mount at a publicly accessible URL in Colab (e.g., via ngrok or Colab's proxy).
6
+
7
+ Usage in training:
8
+ from server.app import broadcast
9
+ broadcast.publish(episode_id, {"type": "episode_update", ...})
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from pathlib import Path
14
+
15
+ import uvicorn
16
+ from fastapi import FastAPI, HTTPException, WebSocket
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import HTMLResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+
21
+ from server.training_broadcast import TrainingBroadcastServer
22
+
23
+ app = FastAPI(title="Optigami Training Server", version="1.0")
24
+
25
+ # Allow cross-origin connections (Colab public URL → browser)
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_credentials=True,
30
+ allow_methods=["*"],
31
+ allow_headers=["*"],
32
+ )
33
+
34
+ # Global broadcast server — import and use from training code
35
+ broadcast = TrainingBroadcastServer()
36
+
37
+
38
+ @app.on_event("startup")
39
+ async def _store_loop() -> None:
40
+ """Capture the asyncio event loop so training threads can schedule coroutines."""
41
+ import asyncio
42
+ broadcast._loop = asyncio.get_running_loop()
43
+
44
+
45
+ @app.websocket("/ws/training")
46
+ async def training_ws(websocket: WebSocket) -> None:
47
+ """Spectator WebSocket endpoint. Viewers connect here to watch training."""
48
+ await broadcast.connect_spectator(websocket)
49
+
50
+
51
+ @app.get("/health")
52
+ def health() -> dict:
53
+ return {
54
+ "status": "ok",
55
+ "spectators": broadcast.spectator_count,
56
+ "active_episodes": broadcast.active_episodes,
57
+ }
58
+
59
+
60
+ # ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
61
+
62
+ @app.get("/targets")
63
+ def get_targets() -> dict:
64
+ from server.tasks import available_task_names, get_task_by_name
65
+ return {
66
+ name: {
67
+ "name": name,
68
+ "level": t["difficulty"],
69
+ "description": t.get("description", ""),
70
+ "n_creases": t.get("max_folds", 3),
71
+ "difficulty": t["difficulty"],
72
+ "material": t.get("material", "paper"),
73
+ }
74
+ for name in available_task_names()
75
+ if (t := get_task_by_name(name))
76
+ }
77
+
78
+
79
+ _DEMO_SEQUENCES: dict[str, list[dict]] = {
80
+ "half_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}],
81
+ "quarter_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
82
+ {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
83
+ "letter_fold": [{"type": "valley", "line": {"start": [0.0, 0.333], "end": [1.0, 0.333]}, "angle": 180.0},
84
+ {"type": "mountain", "line": {"start": [0.0, 0.667], "end": [1.0, 0.667]}, "angle": 180.0}],
85
+ "map_fold": [{"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
86
+ {"type": "mountain", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}],
87
+ "solar_panel": [{"type": "valley", "line": {"start": [0.0, 0.25], "end": [1.0, 0.25]}, "angle": 180.0},
88
+ {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0},
89
+ {"type": "valley", "line": {"start": [0.0, 0.75], "end": [1.0, 0.75]}, "angle": 180.0}],
90
+ }
91
+
92
+
93
+ @app.get("/episode/demo")
94
+ def demo_episode(target: str = "half_fold") -> dict:
95
+ from server.origami_environment import OrigamiEnvironment
96
+ from server.models import OrigamiAction as NewAction
97
+ from server.tasks import get_task_by_name
98
+
99
+ folds = _DEMO_SEQUENCES.get(target, _DEMO_SEQUENCES["half_fold"])
100
+ env = OrigamiEnvironment()
101
+ obs = env.reset(task_name=target)
102
+ steps: list[dict] = []
103
+
104
+ for i, fold_dict in enumerate(folds):
105
+ action = NewAction(
106
+ fold_type=fold_dict["type"],
107
+ fold_line=fold_dict["line"],
108
+ fold_angle=float(fold_dict.get("angle", 180.0)),
109
+ )
110
+ obs = env.step(action)
111
+ steps.append({"step": i + 1, "fold": fold_dict,
112
+ "paper_state": obs.paper_state, "metrics": obs.metrics,
113
+ "done": obs.done})
114
+ if obs.done:
115
+ break
116
+
117
+ return {"task_name": target, "task": get_task_by_name(target) or {},
118
+ "steps": steps, "final_metrics": obs.metrics if steps else {}}
119
+
120
+
121
+ @app.get("/episode/replay/{ep_id}")
122
+ def replay_episode(ep_id: str) -> dict:
123
+ """Return a stored training episode in the same format as /episode/demo."""
124
+ from server.tasks import get_task_by_name
125
+ ep = broadcast._registry.get(ep_id)
126
+ if not ep:
127
+ raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")
128
+ return {
129
+ "task_name": ep.task_name,
130
+ "task": get_task_by_name(ep.task_name) or {},
131
+ "steps": ep.steps,
132
+ "final_metrics": ep.final_metrics or (ep.steps[-1]["metrics"] if ep.steps else {}),
133
+ }
134
+
135
+
136
+ # ── Static files — viewer first, then React app (LAST, catch-all) ──
137
+
138
+ _VIEWER_DIR = Path(__file__).resolve().parent.parent / "viewer"
139
+ _BUILD_DIR = Path(__file__).resolve().parent.parent / "build"
140
+
141
+ if _VIEWER_DIR.exists():
142
+ app.mount("/viewer", StaticFiles(directory=str(_VIEWER_DIR), html=True), name="viewer")
143
+
144
+
145
+ if _BUILD_DIR.exists():
146
+ app.mount("/", StaticFiles(directory=str(_BUILD_DIR), html=True), name="react")
147
+ else:
148
+ @app.get("/", include_in_schema=False)
149
+ def _no_build() -> HTMLResponse:
150
+ return HTMLResponse(
151
+ "<p>React build not found. Run <code>npm run build</code> in the frontend directory.</p>"
152
+ "<p>Training viewer: <a href='/viewer/training.html'>/viewer/training.html</a></p>"
153
+ )
154
+
155
+
156
+ def run(host: str = "0.0.0.0", port: int = 9001) -> None:
157
+ """Start the training server. Call from Colab notebook."""
158
+ uvicorn.run(app, host=host, port=port)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ run()
server/models.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv Pydantic models for the origami RL environment.
3
+
4
+ OrigamiAction — one fold per step
5
+ OrigamiObservation — everything the LLM and Three.js viewer need
6
+ OrigamiState — server-side episode tracking
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Optional
11
+
12
+ from pydantic import Field
13
+
14
+ from openenv.core.env_server.types import Action, Observation, State
15
+
16
+
17
+ class OrigamiAction(Action):
18
+ """One fold operation sent by the client each step."""
19
+
20
+ fold_type: str = Field(
21
+ default="valley",
22
+ description="'valley' | 'mountain' | 'pleat' | 'crimp' | 'stop'",
23
+ )
24
+ fold_line: dict[str, list[float]] = Field(
25
+ default_factory=lambda: {"start": [0.0, 0.5], "end": [1.0, 0.5]},
26
+ description="{'start': [x, y], 'end': [x, y]} normalized 0-1",
27
+ )
28
+ fold_angle: float = Field(
29
+ default=180.0,
30
+ description="Fold angle in degrees, 0-180",
31
+ )
32
+ layer_select: str = Field(
33
+ default="all",
34
+ description="'all' | 'top' | 'bottom'",
35
+ )
36
+
37
+
38
+ class OrigamiObservation(Observation):
39
+ """Everything the LLM and Three.js viewer need.
40
+
41
+ paper_state contains FOLD-compatible geometry + physics data.
42
+ metrics contains all computed quality metrics.
43
+ No render_urls — the browser renders from paper_state directly.
44
+ """
45
+
46
+ task: dict[str, Any] = Field(default_factory=dict)
47
+ paper_state: dict[str, Any] = Field(default_factory=dict)
48
+ metrics: dict[str, Any] = Field(default_factory=dict)
49
+ fold_history: list[dict[str, Any]] = Field(default_factory=list)
50
+ error: Optional[str] = Field(default=None)
51
+
52
+
53
+ class OrigamiState(State):
54
+ """Server-side episode tracking."""
55
+
56
+ task_name: str = Field(default="")
57
+ num_folds_applied: int = Field(default=0)
58
+ is_valid: bool = Field(default=True)
59
+ total_reward: float = Field(default=0.0)
server/origami_environment.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OrigamiEnvironment — OpenEnv environment wrapping the origami physics engine.
3
+
4
+ Implements reset() / step() / state following the OpenEnv interface.
5
+ Engine (physics, fold, validation, metrics) lives in engine/.
6
+ No server-side image rendering — paper_state contains all geometry data.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import uuid
13
+ from typing import Any, Optional
14
+
15
+ from openenv.core.env_server.interfaces import Environment
16
+
17
+ from engine.paper import Paper
18
+ from engine.fold_engine import apply_fold
19
+ from engine.physics import simulate
20
+ from engine.validation import validate_state
21
+ from engine.metrics import compute_all_metrics
22
+ from server.models import OrigamiAction, OrigamiObservation, OrigamiState
23
+ from server.tasks import get_task_by_name, sample_task
24
+
25
+
26
+ def _get_material(name: str):
27
+ """Get material by name, falling back to paper."""
28
+ try:
29
+ from engine.materials import get_material
30
+ return get_material(name)
31
+ except Exception:
32
+ from engine.materials import get_material
33
+ return get_material("paper")
34
+
35
+
36
+ class OrigamiEnvironment(Environment[OrigamiAction, OrigamiObservation, OrigamiState]):
37
+ """Origami folding RL environment.
38
+
39
+ Each episode: agent receives paper_state + task, applies folds one at a
40
+ time via step(), receives metrics + reward, ends with 'stop' action or
41
+ when max_folds is reached.
42
+ """
43
+
44
+ SUPPORTS_CONCURRENT_SESSIONS = False
45
+
46
+ def __init__(self, **kwargs):
47
+ super().__init__(**kwargs)
48
+ self._paper: Optional[Paper] = None
49
+ self._task: Optional[dict] = None
50
+ self._fold_history: list[dict] = []
51
+ self._metrics: dict = {}
52
+ self._validation: dict = {}
53
+ self._error: Optional[str] = None
54
+ self._episode_id: Optional[str] = None
55
+ self._step_count: int = 0
56
+ self._total_reward: float = 0.0
57
+
58
+ # ── reset ─────────────────────────────────────────────────────────
59
+
60
+ def reset(
61
+ self,
62
+ seed: Optional[int] = None,
63
+ episode_id: Optional[str] = None,
64
+ **kwargs: Any,
65
+ ) -> OrigamiObservation:
66
+ self._episode_id = episode_id or str(uuid.uuid4())
67
+ self._step_count = 0
68
+ self._fold_history = []
69
+ self._error = None
70
+ self._total_reward = 0.0
71
+
72
+ # Select task
73
+ task_name = kwargs.get("task_name")
74
+ if task_name:
75
+ self._task = get_task_by_name(task_name)
76
+ if not self._task:
77
+ self._task = sample_task(seed=seed)
78
+
79
+ # Create flat sheet
80
+ mat = _get_material(self._task["material"])
81
+ self._paper = Paper.create_flat_sheet(
82
+ width=self._task["width"],
83
+ height=self._task["height"],
84
+ material=mat,
85
+ )
86
+
87
+ # Initial validation + metrics (no physics needed for flat sheet)
88
+ self._validation = validate_state(self._paper)
89
+ self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
90
+
91
+ return self._make_observation(done=False, reward=None)
92
+
93
+ # ── step ──────────────────────────────────────────────────────────
94
+
95
+ def step(
96
+ self,
97
+ action: OrigamiAction,
98
+ timeout_s: Optional[float] = None,
99
+ **kwargs: Any,
100
+ ) -> OrigamiObservation:
101
+ if self._paper is None or self._task is None:
102
+ return self._make_observation(done=True, reward=-5.0)
103
+
104
+ self._step_count += 1
105
+ self._error = None
106
+
107
+ # ── Stop action ───────────────────────────────────────────────
108
+ if action.fold_type == "stop":
109
+ return self._finalize_episode()
110
+
111
+ # ── Build fold dict ───────────────────────────────────────────
112
+ fold_dict = {
113
+ "type": action.fold_type,
114
+ "line": action.fold_line,
115
+ "angle": action.fold_angle,
116
+ }
117
+
118
+ # ── Apply fold ────────────────────────────────────────────────
119
+ new_paper, err = apply_fold(self._paper, fold_dict)
120
+ if err:
121
+ self._error = err
122
+ return self._make_observation(done=True, reward=-5.0)
123
+
124
+ self._paper = new_paper
125
+ self._fold_history.append({**fold_dict, "step": self._step_count})
126
+
127
+ # ── Physics relaxation ────────────────────────────────────────
128
+ try:
129
+ self._paper = simulate(self._paper, fold_percent=1.0)
130
+ except Exception as exc:
131
+ self._error = f"Physics failed: {exc}"
132
+ # Continue — don't abort episode on physics failure
133
+
134
+ # ── Validate ──────────────────────────────────────────────────
135
+ self._validation = validate_state(self._paper)
136
+
137
+ # ── Metrics ───────────────────────────────────────────────────
138
+ self._metrics = compute_all_metrics(self._paper, self._task, self._validation)
139
+
140
+ # ── Check termination ─────────────────────────────────────────
141
+ max_folds = self._task.get("max_folds", 50)
142
+ if self._step_count >= max_folds:
143
+ return self._finalize_episode()
144
+
145
+ if self._validation.get("self_intersections", 0) > 0:
146
+ self._error = "Self-intersection detected"
147
+ return self._finalize_episode()
148
+
149
+ return self._make_observation(done=False, reward=None)
150
+
151
+ # ── state ─────────────────────────────────────────────────────────
152
+
153
+ @property
154
+ def state(self) -> OrigamiState:
155
+ return OrigamiState(
156
+ episode_id=self._episode_id,
157
+ step_count=self._step_count,
158
+ task_name=self._task.get("name", "") if self._task else "",
159
+ num_folds_applied=len(self._fold_history),
160
+ is_valid=self._metrics.get("is_valid", True),
161
+ total_reward=self._total_reward,
162
+ )
163
+
164
+ # ── internals ─────────────────────────────────────────────────────
165
+
166
+ def _finalize_episode(self) -> OrigamiObservation:
167
+ reward = self._compute_reward()
168
+ self._total_reward = reward
169
+ return self._make_observation(done=True, reward=reward)
170
+
171
+ def _make_observation(self, done: bool, reward: Optional[float]) -> OrigamiObservation:
172
+ return OrigamiObservation(
173
+ done=done,
174
+ reward=reward,
175
+ task=self._task or {},
176
+ paper_state=self._paper.to_observation_dict() if self._paper else {},
177
+ metrics=self._metrics,
178
+ fold_history=self._fold_history,
179
+ error=self._error,
180
+ )
181
+
182
+ def _compute_reward(self) -> float:
183
+ m = self._metrics
184
+ reward = 0.0
185
+
186
+ # Compactness is the main signal
187
+ reward += m.get("compactness", 0.0) * 20.0
188
+
189
+ # Bonus for fitting in target box
190
+ if m.get("fits_target_box", False):
191
+ reward += 10.0
192
+
193
+ # Bonus for deployability (if task requires it)
194
+ if m.get("is_deployable", False):
195
+ reward += 5.0
196
+
197
+ # Penalties for violations
198
+ reward -= m.get("kawasaki_violations", 0) * 2.0
199
+ reward -= m.get("maekawa_violations", 0) * 2.0
200
+ reward -= m.get("self_intersections", 0) * 5.0
201
+
202
+ # Penalty for too many folds (encourage efficiency)
203
+ reward -= m.get("fold_count", 0) * 0.5
204
+
205
+ # Penalty for exceeding material strain limit
206
+ max_strain = m.get("max_strain", 0.0)
207
+ strain_limit = self._paper.material.max_strain if self._paper else 0.05
208
+ if max_strain > strain_limit:
209
+ reward -= 3.0 * (max_strain / strain_limit)
210
+
211
+ return float(reward)
server/tasks.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task pool and curriculum for the origami RL environment.
3
+
4
+ 7 tasks across 4 difficulty levels.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ from typing import Optional
10
+
11
+
12
+ TASKS: dict[str, dict] = {
13
+ "half_fold": {
14
+ "name": "half_fold",
15
+ "description": "Fold a 1x1 paper sheet in half along the horizontal midline.",
16
+ "width": 1.0,
17
+ "height": 1.0,
18
+ "material": "paper",
19
+ "target_ratio": 0.50,
20
+ "max_folds": 3,
21
+ "target_box": [1.0, 0.5, 0.02],
22
+ "must_deploy": False,
23
+ "difficulty": 1,
24
+ },
25
+ "quarter_fold": {
26
+ "name": "quarter_fold",
27
+ "description": "Fold a 1x1 paper sheet into quarters using two perpendicular folds.",
28
+ "width": 1.0,
29
+ "height": 1.0,
30
+ "material": "paper",
31
+ "target_ratio": 0.25,
32
+ "max_folds": 5,
33
+ "target_box": [0.5, 0.5, 0.04],
34
+ "must_deploy": False,
35
+ "difficulty": 1,
36
+ },
37
+ "letter_fold": {
38
+ "name": "letter_fold",
39
+ "description": "Fold a 1x1 paper into thirds (letter fold) using two parallel folds.",
40
+ "width": 1.0,
41
+ "height": 1.0,
42
+ "material": "paper",
43
+ "target_ratio": 0.33,
44
+ "max_folds": 5,
45
+ "target_box": [1.0, 0.34, 0.03],
46
+ "must_deploy": False,
47
+ "difficulty": 2,
48
+ },
49
+ "map_fold": {
50
+ "name": "map_fold",
51
+ "description": "Fold a 1x1 paper into eighths using a grid fold pattern. Must be re-deployable.",
52
+ "width": 1.0,
53
+ "height": 1.0,
54
+ "material": "paper",
55
+ "target_ratio": 0.125,
56
+ "max_folds": 8,
57
+ "target_box": [0.5, 0.25, 0.08],
58
+ "must_deploy": True,
59
+ "difficulty": 2,
60
+ },
61
+ "solar_panel": {
62
+ "name": "solar_panel",
63
+ "description": "Pack a 1x1 Mylar solar panel into a compact configuration using a Miura-ori style fold. Must deploy.",
64
+ "width": 1.0,
65
+ "height": 1.0,
66
+ "material": "mylar",
67
+ "target_ratio": 0.05,
68
+ "max_folds": 20,
69
+ "target_box": [0.25, 0.25, 0.05],
70
+ "must_deploy": True,
71
+ "difficulty": 3,
72
+ },
73
+ "shelter_wall": {
74
+ "name": "shelter_wall",
75
+ "description": "Fold a 1x1 aluminum sheet into a compact structural panel within strain limits.",
76
+ "width": 1.0,
77
+ "height": 1.0,
78
+ "material": "aluminum",
79
+ "target_ratio": 0.10,
80
+ "max_folds": 15,
81
+ "target_box": [0.5, 0.25, 0.1],
82
+ "must_deploy": False,
83
+ "difficulty": 3,
84
+ },
85
+ "stent": {
86
+ "name": "stent",
87
+ "description": "Fold a 0.5x1.5 nitinol sheet into a compact tube configuration for a medical stent. Superelastic material.",
88
+ "width": 0.5,
89
+ "height": 1.5,
90
+ "material": "nitinol",
91
+ "target_ratio": 0.09,
92
+ "max_folds": 25,
93
+ "target_box": [0.1, 0.1, 0.15],
94
+ "must_deploy": True,
95
+ "difficulty": 4,
96
+ },
97
+ }
98
+
99
+
100
+ def get_task_by_name(name: str) -> Optional[dict]:
101
+ """Return task dict by name, or None if not found."""
102
+ return TASKS.get(name)
103
+
104
+
105
+ def sample_task(seed: Optional[int] = None, difficulty: Optional[int] = None) -> dict:
106
+ """Sample a random task, optionally filtered by difficulty level."""
107
+ rng = random.Random(seed)
108
+ pool = list(TASKS.values())
109
+ if difficulty is not None:
110
+ pool = [t for t in pool if t["difficulty"] == difficulty]
111
+ if not pool:
112
+ pool = list(TASKS.values())
113
+ return dict(rng.choice(pool))
114
+
115
+
116
+ def get_tasks_by_difficulty(level: int) -> list[dict]:
117
+ """Return all tasks at a given difficulty level."""
118
+ return [dict(t) for t in TASKS.values() if t["difficulty"] == level]
119
+
120
+
121
+ def available_task_names() -> list[str]:
122
+ """Return sorted list of all task names."""
123
+ return sorted(TASKS.keys())
server/training_broadcast.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TrainingBroadcastServer — fire-and-forget broadcast hub for live training viewer.
3
+
4
+ The RL training process calls publish() after each env.step().
5
+ Spectator browsers connect via /ws/training WebSocket.
6
+ Broadcast is async and non-blocking: if no viewers are connected, observations are dropped.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Optional
15
+
16
+ from fastapi import WebSocket, WebSocketDisconnect
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class EpisodeInfo:
23
+ episode_id: str
24
+ task_name: str
25
+ status: str = "running" # "running" | "done" | "timeout" | "error"
26
+ step: int = 0
27
+ observation: dict = field(default_factory=dict)
28
+ metrics: dict = field(default_factory=dict)
29
+ fold_history: list = field(default_factory=list)
30
+ steps: list = field(default_factory=list) # full step history for replay
31
+ score: Optional[float] = None
32
+ final_metrics: Optional[dict] = None
33
+
34
+
35
+ class TrainingBroadcastServer:
36
+ """Central hub for broadcasting RL training observations to spectator WebSockets.
37
+
38
+ Thread-safe: publish() can be called from training threads (ThreadPoolExecutor).
39
+ WebSocket handlers run in the asyncio event loop.
40
+ """
41
+
42
+ def __init__(self) -> None:
43
+ self._spectators: list[WebSocket] = []
44
+ self._registry: dict[str, EpisodeInfo] = {}
45
+ self._batch_id: int = 0
46
+ self._loop: Optional[asyncio.AbstractEventLoop] = None
47
+ self._lock = asyncio.Lock()
48
+
49
+ # ── Episode publishing (called from training thread / async context) ──
50
+
51
+ def publish(self, episode_id: str, data: dict) -> None:
52
+ """Fire-and-forget: push an update from the training process.
53
+
54
+ Safe to call from any thread. Schedules onto the stored event loop
55
+ (set by the FastAPI startup handler). No-op if no loop is available.
56
+ """
57
+ loop = self._loop
58
+ if loop is None or loop.is_closed():
59
+ return
60
+ asyncio.run_coroutine_threadsafe(self._async_publish(episode_id, data), loop)
61
+
62
+ async def _async_publish(self, episode_id: str, data: dict) -> None:
63
+ msg_type = data.get("type", "episode_update")
64
+
65
+ async with self._lock:
66
+ if msg_type == "batch_start":
67
+ self._batch_id = data.get("batch_id", self._batch_id + 1)
68
+ self._registry.clear()
69
+ await self._broadcast(data)
70
+ return
71
+
72
+ if msg_type == "batch_done":
73
+ await self._broadcast(data)
74
+ return
75
+
76
+ if msg_type == "training_done":
77
+ await self._broadcast(data)
78
+ return
79
+
80
+ # episode_update or episode_done
81
+ ep = self._registry.setdefault(
82
+ episode_id,
83
+ EpisodeInfo(episode_id=episode_id, task_name=data.get("task_name", "")),
84
+ )
85
+
86
+ if msg_type == "episode_done":
87
+ ep.status = data.get("status", "done")
88
+ ep.score = data.get("score")
89
+ ep.final_metrics = data.get("final_metrics")
90
+ else:
91
+ step_num = data.get("step", ep.step)
92
+ ep.step = step_num
93
+ ep.status = "running"
94
+ obs = data.get("observation", {})
95
+ ep.observation = obs
96
+ ep.metrics = obs.get("metrics", {})
97
+ ep.fold_history = obs.get("fold_history", ep.fold_history)
98
+ # Accumulate full step history for /episode/replay
99
+ if step_num > 0:
100
+ fold_hist = obs.get("fold_history", [])
101
+ latest_fold = fold_hist[-1] if fold_hist else {}
102
+ ep.steps.append({
103
+ "step": step_num,
104
+ "fold": latest_fold,
105
+ "paper_state": obs.get("paper_state", {}),
106
+ "metrics": obs.get("metrics", {}),
107
+ "done": obs.get("done", False),
108
+ })
109
+
110
+ await self._broadcast({"episode_id": episode_id, **data})
111
+
112
+ # ── Spectator management ──
113
+
114
+ async def connect_spectator(self, websocket: WebSocket) -> None:
115
+ """Accept a new viewer WebSocket and serve it until disconnect."""
116
+ await websocket.accept()
117
+
118
+ async with self._lock:
119
+ self._spectators.append(websocket)
120
+
121
+ # Send current registry snapshot immediately
122
+ await self._send_registry(websocket)
123
+
124
+ try:
125
+ while True:
126
+ # Viewers are read-only; drain any incoming messages (pings etc)
127
+ await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
128
+ except (WebSocketDisconnect, asyncio.TimeoutError, Exception):
129
+ pass
130
+ finally:
131
+ await self.disconnect_spectator(websocket)
132
+
133
+ async def disconnect_spectator(self, websocket: WebSocket) -> None:
134
+ async with self._lock:
135
+ self._spectators = [s for s in self._spectators if s is not websocket]
136
+
137
+ # ── Batch control ──
138
+
139
+ async def start_batch(self, batch_id: int, num_episodes: int, prompt_index: int = 0) -> None:
140
+ """Call before starting a new training batch."""
141
+ data = {
142
+ "type": "batch_start",
143
+ "batch_id": batch_id,
144
+ "num_episodes": num_episodes,
145
+ "prompt_index": prompt_index,
146
+ }
147
+ await self._async_publish("__batch__", data)
148
+
149
+ async def finish_batch(
150
+ self,
151
+ batch_id: int,
152
+ scores: list[float],
153
+ best_episode_id: str = "",
154
+ ) -> None:
155
+ """Call after all episodes in a batch complete."""
156
+ data = {
157
+ "type": "batch_done",
158
+ "batch_id": batch_id,
159
+ "scores": scores,
160
+ "best_episode_id": best_episode_id,
161
+ "avg_score": sum(scores) / len(scores) if scores else 0.0,
162
+ }
163
+ await self._async_publish("__batch__", data)
164
+
165
+ async def clear_batch(self) -> None:
166
+ """Reset episode registry for next batch."""
167
+ async with self._lock:
168
+ self._registry.clear()
169
+
170
+ # ── Internals ──
171
+
172
+ async def _broadcast(self, message: dict) -> None:
173
+ """Send message to all spectators, removing dead connections."""
174
+ if not self._spectators:
175
+ return
176
+ payload = json.dumps(message, default=str)
177
+ dead: list[WebSocket] = []
178
+ for ws in list(self._spectators):
179
+ try:
180
+ await ws.send_text(payload)
181
+ except Exception:
182
+ dead.append(ws)
183
+ for ws in dead:
184
+ self._spectators = [s for s in self._spectators if s is not ws]
185
+
186
+ async def _send_registry(self, websocket: WebSocket) -> None:
187
+ """Send the full episode registry to a newly connected viewer."""
188
+ async with self._lock:
189
+ episodes = {
190
+ ep_id: {
191
+ "status": ep.status,
192
+ "task": ep.task_name,
193
+ "step": ep.step,
194
+ "observation": ep.observation,
195
+ "metrics": ep.metrics,
196
+ "score": ep.score,
197
+ }
198
+ for ep_id, ep in self._registry.items()
199
+ }
200
+ payload = {
201
+ "type": "registry",
202
+ "batch_id": self._batch_id,
203
+ "episodes": episodes,
204
+ }
205
+ try:
206
+ await websocket.send_text(json.dumps(payload, default=str))
207
+ except Exception:
208
+ pass
209
+
210
+ @property
211
+ def spectator_count(self) -> int:
212
+ return len(self._spectators)
213
+
214
+ @property
215
+ def active_episodes(self) -> int:
216
+ return sum(1 for ep in self._registry.values() if ep.status == "running")
server.py → server_legacy.py RENAMED
File without changes
src/App.css CHANGED
@@ -67,6 +67,30 @@
67
  margin-left: auto;
68
  }
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  .api-status {
71
  font-size: 11px;
72
  font-family: var(--font-display);
 
67
  margin-left: auto;
68
  }
69
 
70
+ .replay-badge {
71
+ font-size: 10px;
72
+ font-family: var(--font-display);
73
+ letter-spacing: 0.1em;
74
+ color: #38bdf8;
75
+ background: rgba(56, 189, 248, 0.1);
76
+ border: 1px solid rgba(56, 189, 248, 0.3);
77
+ padding: 3px 8px;
78
+ border-radius: 3px;
79
+ }
80
+
81
+ .back-to-grid-btn {
82
+ font-size: 10px;
83
+ font-family: var(--font-display);
84
+ letter-spacing: 0.08em;
85
+ color: #64748b;
86
+ background: transparent;
87
+ border: 1px solid #1e2a3a;
88
+ padding: 3px 10px;
89
+ border-radius: 3px;
90
+ cursor: pointer;
91
+ }
92
+ .back-to-grid-btn:hover { color: #e2e8f0; border-color: #64748b; }
93
+
94
  .api-status {
95
  font-size: 11px;
96
  font-family: var(--font-display);
src/App.js CHANGED
@@ -10,17 +10,22 @@ import Fold3DCanvas from './components/Fold3DCanvas';
10
 
11
  const API_BASE = '';
12
 
 
 
 
 
13
  function App() {
14
  const [targets, setTargets] = useState({});
15
- const [selectedTarget, setSelectedTarget] = useState('half_horizontal');
16
  const [episode, setEpisode] = useState(null);
17
  const [currentStep, setCurrentStep] = useState(0);
18
  const [playing, setPlaying] = useState(false);
19
- const [foldRenderMode, setFoldRenderMode] = useState('progressive'); // 'progressive' | 'final'
20
- const [apiStatus, setApiStatus] = useState('connecting'); // 'connecting' | 'ok' | 'err'
21
  const [episodeLoading, setEpisodeLoading] = useState(false);
22
  const intervalRef = useRef(null);
23
 
 
 
24
  const fetchTargets = useCallback(async () => {
25
  try {
26
  const res = await fetch(`${API_BASE}/targets`);
@@ -51,13 +56,35 @@ function App() {
51
  }
52
  }, []);
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  useEffect(() => {
55
  fetchTargets();
56
  }, [fetchTargets]);
57
 
58
  useEffect(() => {
59
- fetchDemoEpisode(selectedTarget);
60
- }, [selectedTarget, fetchDemoEpisode]);
 
 
 
 
61
 
62
  const totalSteps = episode ? episode.steps.length : 0;
63
 
@@ -99,7 +126,6 @@ function App() {
99
  };
100
 
101
  const targetDef = targets[selectedTarget] || null;
102
- const targetFold = episode ? episode.target : null;
103
 
104
  return (
105
  <div className="app">
@@ -108,11 +134,20 @@ function App() {
108
  OPTI<span className="title-accent">GAMI</span> RL
109
  </span>
110
  <div className="header-sep" />
111
- <TargetSelector
112
- targets={targets}
113
- selected={selectedTarget}
114
- onChange={name => setSelectedTarget(name)}
115
- />
 
 
 
 
 
 
 
 
 
116
  <div className="header-sep" />
117
  <PlayerControls
118
  playing={playing}
@@ -138,12 +173,12 @@ function App() {
138
  <div className="canvas-row">
139
  <div className="canvas-wrap">
140
  <span className="canvas-label">
141
- TARGET — {targetDef ? targetDef.name.replace(/_/g, ' ').toUpperCase() : '—'}
142
  </span>
143
  <CreaseCanvas
144
  paperState={null}
145
- target={targetFold}
146
- label="TARGET"
147
  dim={280}
148
  ghostOnly={true}
149
  />
@@ -154,7 +189,7 @@ function App() {
154
  </span>
155
  <CreaseCanvas
156
  paperState={activeStepData ? activeStepData.paper_state : null}
157
- target={targetFold}
158
  label={currentStep === 0 ? 'INITIAL' : `STEP ${currentStep}`}
159
  dim={280}
160
  ghostOnly={false}
@@ -163,28 +198,10 @@ function App() {
163
  <div className="canvas-wrap">
164
  <div className="canvas-label-row">
165
  <span className="canvas-label">3D FOLD PREVIEW</span>
166
- <div className="fold-mode-toggle">
167
- <button
168
- className={`fold-mode-btn${foldRenderMode === 'progressive' ? ' active' : ''}`}
169
- onClick={() => setFoldRenderMode('progressive')}
170
- type="button"
171
- >
172
- PER CREASE
173
- </button>
174
- <button
175
- className={`fold-mode-btn${foldRenderMode === 'final' ? ' active' : ''}`}
176
- onClick={() => setFoldRenderMode('final')}
177
- type="button"
178
- >
179
- FOLD AT END
180
- </button>
181
- </div>
182
  </div>
183
  <Fold3DCanvas
184
  steps={episode ? episode.steps : []}
185
  currentStep={currentStep}
186
- totalSteps={totalSteps}
187
- mode={foldRenderMode}
188
  dim={280}
189
  />
190
  </div>
@@ -207,10 +224,14 @@ function App() {
207
  </div>
208
 
209
  <div className="app-right">
210
- <div className="section-header">REWARD DECOMPOSITION</div>
211
- <RewardPanel reward={activeStepData ? activeStepData.reward : null} />
212
  <div className="section-header">EPISODE INFO</div>
213
- <InfoBadges info={activeStepData ? activeStepData.info : null} targetDef={targetDef} />
 
 
 
 
214
  </div>
215
  </div>
216
  </div>
 
10
 
11
  const API_BASE = '';
12
 
13
+ // Read ?ep=<episode_id> from URL — set when navigating from training grid
14
+ const _urlParams = new URLSearchParams(window.location.search);
15
+ const REPLAY_EP_ID = _urlParams.get('ep') || null;
16
+
17
  function App() {
18
  const [targets, setTargets] = useState({});
19
+ const [selectedTarget, setSelectedTarget] = useState('half_fold');
20
  const [episode, setEpisode] = useState(null);
21
  const [currentStep, setCurrentStep] = useState(0);
22
  const [playing, setPlaying] = useState(false);
23
+ const [apiStatus, setApiStatus] = useState('connecting');
 
24
  const [episodeLoading, setEpisodeLoading] = useState(false);
25
  const intervalRef = useRef(null);
26
 
27
+ const isReplayMode = REPLAY_EP_ID !== null;
28
+
29
  const fetchTargets = useCallback(async () => {
30
  try {
31
  const res = await fetch(`${API_BASE}/targets`);
 
56
  }
57
  }, []);
58
 
59
+ const fetchReplayEpisode = useCallback(async (epId) => {
60
+ setEpisodeLoading(true);
61
+ setPlaying(false);
62
+ setCurrentStep(0);
63
+ try {
64
+ const res = await fetch(`${API_BASE}/episode/replay/${epId}`);
65
+ if (!res.ok) throw new Error(`HTTP ${res.status}`);
66
+ const data = await res.json();
67
+ setEpisode(data);
68
+ setApiStatus('ok');
69
+ } catch {
70
+ setEpisode(null);
71
+ setApiStatus('err');
72
+ } finally {
73
+ setEpisodeLoading(false);
74
+ }
75
+ }, []);
76
+
77
  useEffect(() => {
78
  fetchTargets();
79
  }, [fetchTargets]);
80
 
81
  useEffect(() => {
82
+ if (isReplayMode) {
83
+ fetchReplayEpisode(REPLAY_EP_ID);
84
+ } else {
85
+ fetchDemoEpisode(selectedTarget);
86
+ }
87
+ }, [isReplayMode, selectedTarget, fetchDemoEpisode, fetchReplayEpisode]);
88
 
89
  const totalSteps = episode ? episode.steps.length : 0;
90
 
 
126
  };
127
 
128
  const targetDef = targets[selectedTarget] || null;
 
129
 
130
  return (
131
  <div className="app">
 
134
  OPTI<span className="title-accent">GAMI</span> RL
135
  </span>
136
  <div className="header-sep" />
137
+ {isReplayMode ? (
138
+ <>
139
+ <span className="replay-badge">REPLAY — {REPLAY_EP_ID}</span>
140
+ <button className="back-to-grid-btn" onClick={() => window.history.back()}>
141
+ ← GRID
142
+ </button>
143
+ </>
144
+ ) : (
145
+ <TargetSelector
146
+ targets={targets}
147
+ selected={selectedTarget}
148
+ onChange={name => setSelectedTarget(name)}
149
+ />
150
+ )}
151
  <div className="header-sep" />
152
  <PlayerControls
153
  playing={playing}
 
173
  <div className="canvas-row">
174
  <div className="canvas-wrap">
175
  <span className="canvas-label">
176
+ TASK — {targetDef ? targetDef.name.replace(/_/g, ' ').toUpperCase() : '—'}
177
  </span>
178
  <CreaseCanvas
179
  paperState={null}
180
+ target={null}
181
+ label="TASK"
182
  dim={280}
183
  ghostOnly={true}
184
  />
 
189
  </span>
190
  <CreaseCanvas
191
  paperState={activeStepData ? activeStepData.paper_state : null}
192
+ target={null}
193
  label={currentStep === 0 ? 'INITIAL' : `STEP ${currentStep}`}
194
  dim={280}
195
  ghostOnly={false}
 
198
  <div className="canvas-wrap">
199
  <div className="canvas-label-row">
200
  <span className="canvas-label">3D FOLD PREVIEW</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  </div>
202
  <Fold3DCanvas
203
  steps={episode ? episode.steps : []}
204
  currentStep={currentStep}
 
 
205
  dim={280}
206
  />
207
  </div>
 
224
  </div>
225
 
226
  <div className="app-right">
227
+ <div className="section-header">METRICS</div>
228
+ <RewardPanel metrics={activeStepData ? activeStepData.metrics : null} />
229
  <div className="section-header">EPISODE INFO</div>
230
+ <InfoBadges
231
+ metrics={activeStepData ? activeStepData.metrics : null}
232
+ paperState={activeStepData ? activeStepData.paper_state : null}
233
+ targetDef={targetDef}
234
+ />
235
  </div>
236
  </div>
237
  </div>
src/components/CreaseCanvas.js CHANGED
@@ -13,10 +13,11 @@ function GhostEdges({ target, dim }) {
13
  return edges_vertices.map((ev, i) => {
14
  const asgn = edges_assignment[i];
15
  if (asgn === 'B') return null;
16
- const [v1x, v1y] = vertices_coords[ev[0]];
17
- const [v2x, v2y] = vertices_coords[ev[1]];
18
- const [x1, y1] = toSvg(v1x, v1y, dim);
19
- const [x2, y2] = toSvg(v2x, v2y, dim);
 
20
  const color = asgn === 'M' ? MOUNTAIN : VALLEY;
21
  return (
22
  <line
@@ -32,15 +33,23 @@ function GhostEdges({ target, dim }) {
32
  }
33
 
34
  function CurrentEdges({ paperState, dim }) {
35
- if (!paperState || !paperState.edges) return null;
36
- return paperState.edges.map((edge) => {
37
- if (edge.assignment === 'B') return null;
38
- const [x1, y1] = toSvg(edge.v1[0], edge.v1[1], dim);
39
- const [x2, y2] = toSvg(edge.v2[0], edge.v2[1], dim);
40
- const color = edge.assignment === 'M' ? MOUNTAIN : VALLEY;
 
 
 
 
 
 
 
 
41
  return (
42
  <line
43
- key={edge.id}
44
  x1={x1} y1={y1} x2={x2} y2={y2}
45
  stroke={color}
46
  strokeWidth={2.5}
@@ -50,26 +59,6 @@ function CurrentEdges({ paperState, dim }) {
50
  });
51
  }
52
 
53
- function AnchorCrosses({ paperState, dim }) {
54
- if (!paperState || !paperState.anchor_points) return null;
55
- const size = 4;
56
- return paperState.anchor_points.map((pt, i) => {
57
- const [cx, cy] = toSvg(pt[0], pt[1], dim);
58
- return (
59
- <g key={i}>
60
- <line
61
- x1={cx - size} y1={cy} x2={cx + size} y2={cy}
62
- stroke="#64748b" strokeWidth={1}
63
- />
64
- <line
65
- x1={cx} y1={cy - size} x2={cx} y2={cy + size}
66
- stroke="#64748b" strokeWidth={1}
67
- />
68
- </g>
69
- );
70
- });
71
- }
72
-
73
  export default function CreaseCanvas({ paperState, target, dim = 280, ghostOnly = false }) {
74
  const pad = 1;
75
  const size = dim;
@@ -94,10 +83,7 @@ export default function CreaseCanvas({ paperState, target, dim = 280, ghostOnly
94
 
95
  {/* Current paper state */}
96
  {!ghostOnly && (
97
- <>
98
- <CurrentEdges paperState={paperState} dim={size} />
99
- <AnchorCrosses paperState={paperState} dim={size} />
100
- </>
101
  )}
102
 
103
  {/* Paper border */}
 
13
  return edges_vertices.map((ev, i) => {
14
  const asgn = edges_assignment[i];
15
  if (asgn === 'B') return null;
16
+ const v1 = vertices_coords[ev[0]];
17
+ const v2 = vertices_coords[ev[1]];
18
+ if (!v1 || !v2) return null;
19
+ const [x1, y1] = toSvg(v1[0], v1[1], dim);
20
+ const [x2, y2] = toSvg(v2[0], v2[1], dim);
21
  const color = asgn === 'M' ? MOUNTAIN : VALLEY;
22
  return (
23
  <line
 
33
  }
34
 
35
  function CurrentEdges({ paperState, dim }) {
36
+ if (!paperState) return null;
37
+ const { vertices_coords, edges_vertices, edges_assignment } = paperState;
38
+ if (!vertices_coords || !edges_vertices || !edges_assignment) return null;
39
+
40
+ return edges_vertices.map((ev, i) => {
41
+ const asgn = edges_assignment[i];
42
+ if (asgn === 'B' || asgn === 'F') return null;
43
+ const v1 = vertices_coords[ev[0]];
44
+ const v2 = vertices_coords[ev[1]];
45
+ if (!v1 || !v2) return null;
46
+ // vertices_coords are 3D [x, y, z] — use only x and y
47
+ const [x1, y1] = toSvg(v1[0], v1[1], dim);
48
+ const [x2, y2] = toSvg(v2[0], v2[1], dim);
49
+ const color = asgn === 'M' ? MOUNTAIN : VALLEY;
50
  return (
51
  <line
52
+ key={i}
53
  x1={x1} y1={y1} x2={x2} y2={y2}
54
  stroke={color}
55
  strokeWidth={2.5}
 
59
  });
60
  }
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  export default function CreaseCanvas({ paperState, target, dim = 280, ghostOnly = false }) {
63
  const pad = 1;
64
  const size = dim;
 
83
 
84
  {/* Current paper state */}
85
  {!ghostOnly && (
86
+ <CurrentEdges paperState={paperState} dim={size} />
 
 
 
87
  )}
88
 
89
  {/* Paper border */}
src/components/Fold3DCanvas.js CHANGED
@@ -1,11 +1,9 @@
1
- import { useCallback, useEffect, useMemo, useRef } from 'react';
2
 
3
  const PAPER_RGB = [250, 250, 245];
4
  const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
5
- const MAX_FOLD_RAD = Math.PI * 0.92;
6
- const SIDE_EPS = 1e-7;
7
- const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.95)';
8
- const VALLEY_COLOR = 'rgba(56, 189, 248, 0.95)';
9
 
10
  function clamp(value, min, max) {
11
  return Math.min(Math.max(value, min), max);
@@ -41,102 +39,23 @@ function shadePaper(intensity) {
41
  return `rgb(${r}, ${g}, ${b})`;
42
  }
43
 
44
- function buildGridMesh(resolution = 18) {
45
- const vertices = [];
46
- for (let y = 0; y <= resolution; y += 1) {
47
- for (let x = 0; x <= resolution; x += 1) {
48
- vertices.push([x / resolution, y / resolution, 0]);
49
- }
50
- }
51
-
52
- const triangles = [];
53
- const stride = resolution + 1;
54
- for (let y = 0; y < resolution; y += 1) {
55
- for (let x = 0; x < resolution; x += 1) {
56
- const a = y * stride + x;
57
- const b = a + 1;
58
- const c = a + stride;
59
- const d = c + 1;
60
- triangles.push([a, b, d]);
61
- triangles.push([a, d, c]);
62
- }
63
- }
64
-
65
- return { vertices, triangles, resolution };
66
- }
67
-
68
- function rotateAroundAxis(point, axisPoint, axisDir, angleRad) {
69
- const px = point[0] - axisPoint[0];
70
- const py = point[1] - axisPoint[1];
71
- const pz = point[2] - axisPoint[2];
72
-
73
- const kx = axisDir[0];
74
- const ky = axisDir[1];
75
- const kz = axisDir[2];
76
-
77
- const cosA = Math.cos(angleRad);
78
- const sinA = Math.sin(angleRad);
79
-
80
- const crossX = ky * pz - kz * py;
81
- const crossY = kz * px - kx * pz;
82
- const crossZ = kx * py - ky * px;
83
-
84
- const dot = px * kx + py * ky + pz * kz;
85
- const oneMinus = 1.0 - cosA;
86
-
87
- return [
88
- axisPoint[0] + px * cosA + crossX * sinA + kx * dot * oneMinus,
89
- axisPoint[1] + py * cosA + crossY * sinA + ky * dot * oneMinus,
90
- axisPoint[2] + pz * cosA + crossZ * sinA + kz * dot * oneMinus,
91
- ];
92
- }
93
-
94
- function applyFoldToVertices(vertices, fold, progress) {
95
- if (!fold || progress <= 0) return;
96
- const [x1, y1] = fold.from;
97
- const [x2, y2] = fold.to;
98
- const dx = x2 - x1;
99
- const dy = y2 - y1;
100
- const len = Math.hypot(dx, dy);
101
- if (len < 1e-8) return;
102
-
103
- const sideValues = [];
104
- let posCount = 0;
105
- let negCount = 0;
106
-
107
- for (let i = 0; i < vertices.length; i += 1) {
108
- const v = vertices[i];
109
- const side = dx * (v[1] - y1) - dy * (v[0] - x1);
110
- sideValues.push(side);
111
- if (side > SIDE_EPS) posCount += 1;
112
- else if (side < -SIDE_EPS) negCount += 1;
113
- }
114
-
115
- let rotatePositive = posCount <= negCount;
116
- if (posCount === 0 && negCount > 0) rotatePositive = false;
117
- if (negCount === 0 && posCount > 0) rotatePositive = true;
118
- if (posCount === 0 && negCount === 0) return;
119
-
120
- const sign = fold.assignment === 'V' ? 1 : -1;
121
- const angle = sign * MAX_FOLD_RAD * progress;
122
- const axisPoint = [x1, y1, 0];
123
- const axisDir = [dx / len, dy / len, 0];
124
-
125
- for (let i = 0; i < vertices.length; i += 1) {
126
- const side = sideValues[i];
127
- const shouldRotate = rotatePositive ? side > SIDE_EPS : side < -SIDE_EPS;
128
- if (!shouldRotate) continue;
129
- vertices[i] = rotateAroundAxis(vertices[i], axisPoint, axisDir, angle);
130
- }
131
  }
132
 
133
  function projectVertex(vertex, dim) {
134
  let x = vertex[0] - 0.5;
135
  let y = vertex[1] - 0.5;
136
- let z = vertex[2];
137
 
138
- const pitch = 1.04;
139
- const yaw = -0.78;
140
 
141
  const cp = Math.cos(pitch);
142
  const sp = Math.sin(pitch);
@@ -158,162 +77,119 @@ function projectVertex(vertex, dim) {
158
  };
159
  }
160
 
161
- function foldProgresses(stepValue, foldCount, mode, totalSteps) {
162
- const values = new Array(foldCount).fill(0);
163
- if (foldCount === 0) return values;
164
-
165
- if (mode === 'final') {
166
- const startCollapse = Math.max(totalSteps - 1, 0);
167
- const collapse = clamp(stepValue - startCollapse, 0, 1);
168
- for (let i = 0; i < foldCount; i += 1) values[i] = collapse;
169
- return values;
 
 
170
  }
171
 
172
- for (let i = 0; i < foldCount; i += 1) {
173
- if (stepValue >= i + 1) values[i] = 1;
174
- else if (stepValue > i) values[i] = clamp(stepValue - i, 0, 1);
175
- }
176
- return values;
177
- }
178
-
179
- function stepEasing(t) {
180
- return t < 0.5 ? 4 * t * t * t : 1 - ((-2 * t + 2) ** 3) / 2;
181
- }
182
-
183
- export default function Fold3DCanvas({
184
- steps,
185
- currentStep,
186
- totalSteps,
187
- mode = 'progressive',
188
- dim = 280,
189
- }) {
190
- const canvasRef = useRef(null);
191
- const rafRef = useRef(null);
192
- const animatedStepRef = useRef(currentStep);
193
-
194
- const folds = useMemo(
195
- () => (steps || [])
196
- .map((s) => s.fold)
197
- .filter(Boolean)
198
- .map((fold) => ({
199
- from: [Number(fold.from_point[0]), Number(fold.from_point[1])],
200
- to: [Number(fold.to_point[0]), Number(fold.to_point[1])],
201
- assignment: fold.assignment === 'M' ? 'M' : 'V',
202
- })),
203
- [steps],
204
- );
205
-
206
- const mesh = useMemo(() => buildGridMesh(18), []);
207
-
208
- const draw = useCallback((stepValue) => {
209
- const canvas = canvasRef.current;
210
- if (!canvas) return;
211
- const ctx = canvas.getContext('2d');
212
- if (!ctx) return;
213
 
214
- ctx.clearRect(0, 0, dim, dim);
215
  ctx.fillStyle = '#121220';
216
  ctx.fillRect(0, 0, dim, dim);
 
 
217
 
218
- const vertices = mesh.vertices.map((v) => [v[0], v[1], v[2]]);
219
- const progress = foldProgresses(stepValue, folds.length, mode, totalSteps);
220
-
221
- for (let i = 0; i < folds.length; i += 1) {
222
- if (progress[i] <= 0) continue;
223
- applyFoldToVertices(vertices, folds[i], progress[i]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  }
 
 
225
 
226
- const projected = vertices.map((v) => projectVertex(v, dim));
 
227
 
228
- const tris = mesh.triangles.map((tri) => {
229
- const p0 = projected[tri[0]];
230
- const p1 = projected[tri[1]];
231
- const p2 = projected[tri[2]];
 
 
 
 
 
232
  const avgZ = (p0.z + p1.z + p2.z) / 3;
233
-
234
- const v0 = vertices[tri[0]];
235
- const v1 = vertices[tri[1]];
236
- const v2 = vertices[tri[2]];
237
  const normal = normalize3(cross3(sub3(v1, v0), sub3(v2, v0)));
238
  const intensity = dot3(normal, LIGHT_DIR);
 
 
 
 
 
 
239
 
240
- return {
241
- tri,
242
- avgZ,
243
- shade: shadePaper(intensity),
244
- };
245
- });
246
 
247
- tris.sort((a, b) => a.avgZ - b.avgZ);
 
 
 
248
 
249
- for (const triInfo of tris) {
250
- const [a, b, c] = triInfo.tri;
251
- const p0 = projected[a];
252
- const p1 = projected[b];
253
- const p2 = projected[c];
254
-
255
- ctx.beginPath();
256
- ctx.moveTo(p0.x, p0.y);
257
- ctx.lineTo(p1.x, p1.y);
258
- ctx.lineTo(p2.x, p2.y);
259
- ctx.closePath();
260
- ctx.fillStyle = triInfo.shade;
261
- ctx.fill();
262
- ctx.strokeStyle = 'rgba(42, 42, 58, 0.22)';
263
- ctx.lineWidth = 0.55;
264
- ctx.stroke();
265
- }
266
 
267
- const res = mesh.resolution;
268
- const stride = res + 1;
269
- const pointToIndex = (pt) => {
270
- const ix = clamp(Math.round(pt[0] * res), 0, res);
271
- const iy = clamp(Math.round(pt[1] * res), 0, res);
272
- return iy * stride + ix;
273
- };
274
 
275
- for (let i = 0; i < folds.length; i += 1) {
276
- if (progress[i] <= 0.02) continue;
277
- const fold = folds[i];
278
- const aIdx = pointToIndex(fold.from);
279
- const bIdx = pointToIndex(fold.to);
280
- const pa = projected[aIdx];
281
- const pb = projected[bIdx];
282
 
283
- ctx.beginPath();
284
- ctx.moveTo(pa.x, pa.y);
285
- ctx.lineTo(pb.x, pb.y);
286
- ctx.strokeStyle = fold.assignment === 'M' ? MOUNTAIN_COLOR : VALLEY_COLOR;
287
- ctx.globalAlpha = clamp(0.35 + 0.65 * progress[i], 0, 1);
288
- ctx.lineWidth = 2.15;
289
- ctx.stroke();
290
- ctx.globalAlpha = 1;
291
- }
292
- }, [dim, folds, mesh, mode, totalSteps]);
293
 
294
- useEffect(() => {
295
- draw(animatedStepRef.current);
296
- }, [draw]);
 
 
297
 
298
  useEffect(() => {
299
- cancelAnimationFrame(rafRef.current);
300
- const startValue = animatedStepRef.current;
301
- const endValue = currentStep;
302
- const durationMs = 420;
303
- const startAt = performance.now();
304
-
305
- const tick = (now) => {
306
- const t = clamp((now - startAt) / durationMs, 0, 1);
307
- const eased = stepEasing(t);
308
- const value = startValue + (endValue - startValue) * eased;
309
- animatedStepRef.current = value;
310
- draw(value);
311
- if (t < 1) rafRef.current = requestAnimationFrame(tick);
312
- };
313
-
314
- rafRef.current = requestAnimationFrame(tick);
315
- return () => cancelAnimationFrame(rafRef.current);
316
- }, [currentStep, draw]);
317
 
318
  return (
319
  <canvas
 
1
+ import { useCallback, useEffect, useRef } from 'react';
2
 
3
  const PAPER_RGB = [250, 250, 245];
4
  const LIGHT_DIR = normalize3([0.4, -0.45, 1.0]);
5
+ const MOUNTAIN_COLOR = 'rgba(245, 158, 11, 0.9)';
6
+ const VALLEY_COLOR = 'rgba(56, 189, 248, 0.9)';
 
 
7
 
8
  function clamp(value, min, max) {
9
  return Math.min(Math.max(value, min), max);
 
39
  return `rgb(${r}, ${g}, ${b})`;
40
  }
41
 
42
+ function strainColor(strain, intensity) {
43
+ const t = clamp(strain / 0.15, 0, 1);
44
+ const lit = clamp(0.3 + 0.7 * Math.abs(intensity), 0, 1);
45
+ // Blend from paper ivory to red-orange with lighting
46
+ const r = Math.round((250 + t * 5) * lit);
47
+ const g = Math.round((250 - t * 200) * lit);
48
+ const bv = Math.round((245 - t * 245) * lit);
49
+ return `rgb(${clamp(r,0,255)}, ${clamp(g,0,255)}, ${clamp(bv,0,255)})`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  }
51
 
52
  function projectVertex(vertex, dim) {
53
  let x = vertex[0] - 0.5;
54
  let y = vertex[1] - 0.5;
55
+ let z = vertex[2] || 0;
56
 
57
+ const pitch = 0.62;
58
+ const yaw = -0.52;
59
 
60
  const cp = Math.cos(pitch);
61
  const sp = Math.sin(pitch);
 
77
  };
78
  }
79
 
80
+ function drawPaperState(ctx, paperState, dim) {
81
+ ctx.clearRect(0, 0, dim, dim);
82
+ ctx.fillStyle = '#121220';
83
+ ctx.fillRect(0, 0, dim, dim);
84
+
85
+ if (!paperState) {
86
+ // Draw flat sheet for initial state
87
+ const flatVerts = [[0,0,0],[1,0,0],[1,1,0],[0,1,0]];
88
+ const flatFaces = [[0,1,2],[0,2,3]];
89
+ renderMesh(ctx, flatVerts, flatFaces, null, dim);
90
+ return;
91
  }
92
 
93
+ const { vertices_coords, faces_vertices, strain_per_vertex, edges_vertices, edges_assignment } = paperState;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ if (!vertices_coords || !faces_vertices) {
96
  ctx.fillStyle = '#121220';
97
  ctx.fillRect(0, 0, dim, dim);
98
+ return;
99
+ }
100
 
101
+ renderMesh(ctx, vertices_coords, faces_vertices, strain_per_vertex, dim);
102
+
103
+ // Draw fold creases on top
104
+ if (edges_vertices && edges_assignment) {
105
+ const projected = vertices_coords.map(v => projectVertex(v, dim));
106
+ for (let i = 0; i < edges_vertices.length; i++) {
107
+ const asgn = edges_assignment[i];
108
+ if (asgn !== 'M' && asgn !== 'V') continue;
109
+ const [ai, bi] = edges_vertices[i];
110
+ const pa = projected[ai];
111
+ const pb = projected[bi];
112
+ if (!pa || !pb) continue;
113
+ ctx.beginPath();
114
+ ctx.moveTo(pa.x, pa.y);
115
+ ctx.lineTo(pb.x, pb.y);
116
+ ctx.strokeStyle = asgn === 'M' ? MOUNTAIN_COLOR : VALLEY_COLOR;
117
+ ctx.lineWidth = 2.0;
118
+ ctx.globalAlpha = 0.85;
119
+ ctx.stroke();
120
+ ctx.globalAlpha = 1;
121
  }
122
+ }
123
+ }
124
 
125
+ function renderMesh(ctx, verts, faces, strain, dim) {
126
+ const projected = verts.map(v => projectVertex(v, dim));
127
 
128
+ const tris = [];
129
+ for (const face of faces) {
130
+ // Triangulate face (fan from first vertex)
131
+ for (let k = 1; k < face.length - 1; k++) {
132
+ const a = face[0], b = face[k], c = face[k + 1];
133
+ const p0 = projected[a];
134
+ const p1 = projected[b];
135
+ const p2 = projected[c];
136
+ if (!p0 || !p1 || !p2) continue;
137
  const avgZ = (p0.z + p1.z + p2.z) / 3;
138
+ const v0 = verts[a], v1 = verts[b], v2 = verts[c];
 
 
 
139
  const normal = normalize3(cross3(sub3(v1, v0), sub3(v2, v0)));
140
  const intensity = dot3(normal, LIGHT_DIR);
141
+ const avgStrain = strain
142
+ ? ((strain[a] || 0) + (strain[b] || 0) + (strain[c] || 0)) / 3
143
+ : 0;
144
+ tris.push({ a, b, c, avgZ, intensity, avgStrain });
145
+ }
146
+ }
147
 
148
+ tris.sort((x, y) => x.avgZ - y.avgZ);
 
 
 
 
 
149
 
150
+ for (const tri of tris) {
151
+ const p0 = projected[tri.a];
152
+ const p1 = projected[tri.b];
153
+ const p2 = projected[tri.c];
154
 
155
+ ctx.beginPath();
156
+ ctx.moveTo(p0.x, p0.y);
157
+ ctx.lineTo(p1.x, p1.y);
158
+ ctx.lineTo(p2.x, p2.y);
159
+ ctx.closePath();
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ const fillColor = tri.avgStrain > 0.005
162
+ ? strainColor(tri.avgStrain, tri.intensity)
163
+ : shadePaper(tri.intensity);
 
 
 
 
164
 
165
+ ctx.fillStyle = fillColor;
166
+ ctx.fill();
167
+ ctx.strokeStyle = 'rgba(42, 42, 58, 0.22)';
168
+ ctx.lineWidth = 0.55;
169
+ ctx.stroke();
170
+ }
171
+ }
172
 
173
+ export default function Fold3DCanvas({
174
+ steps,
175
+ currentStep,
176
+ dim = 280,
177
+ }) {
178
+ const canvasRef = useRef(null);
 
 
 
 
179
 
180
+ const getPaperState = useCallback(() => {
181
+ if (!steps || steps.length === 0 || currentStep === 0) return null;
182
+ const stepData = steps[currentStep - 1];
183
+ return stepData ? stepData.paper_state : null;
184
+ }, [steps, currentStep]);
185
 
186
  useEffect(() => {
187
+ const canvas = canvasRef.current;
188
+ if (!canvas) return;
189
+ const ctx = canvas.getContext('2d');
190
+ if (!ctx) return;
191
+ drawPaperState(ctx, getPaperState(), dim);
192
+ }, [getPaperState, dim]);
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  return (
195
  <canvas
src/components/InfoBadges.js CHANGED
@@ -27,31 +27,31 @@ function NumVal({ value }) {
27
  return <span className="info-val">{value}</span>;
28
  }
29
 
30
- export default function InfoBadges({ info, targetDef }) {
 
 
 
31
  return (
32
  <div className="info-badges">
33
  <div className="info-row">
34
- <span className="info-key">n_creases</span>
35
- <NumVal value={info ? info.n_creases : (targetDef ? targetDef.n_creases : null)} />
36
  </div>
37
  <div className="info-row">
38
- <span className="info-key">interior_verts</span>
39
- <NumVal value={info ? info.n_interior_vertices : null} />
40
  </div>
41
  <div className="info-row">
42
- <span className="info-key">local_fold</span>
43
- <BoolVal value={info ? info.local_foldability : null} />
44
  </div>
45
  <div className="info-row">
46
- <span className="info-key">blb_sat</span>
47
- <BoolVal value={info ? info.blb_satisfied : null} />
48
  </div>
49
  <div className="info-row">
50
- <span className="info-key">global_fold</span>
51
- <TextVal
52
- value={info ? info.global_foldability : null}
53
- dim={true}
54
- />
55
  </div>
56
  {targetDef && (
57
  <>
@@ -60,9 +60,13 @@ export default function InfoBadges({ info, targetDef }) {
60
  <span className="info-val">LVL {targetDef.level}</span>
61
  </div>
62
  <div className="info-row">
63
- <span className="info-key">target</span>
 
 
 
 
64
  <span className="info-val" style={{ fontSize: '10px', textAlign: 'right', maxWidth: '140px', wordBreak: 'break-word' }}>
65
- {targetDef.name.replace(/_/g, ' ').toUpperCase()}
66
  </span>
67
  </div>
68
  </>
 
27
  return <span className="info-val">{value}</span>;
28
  }
29
 
30
+ export default function InfoBadges({ metrics, paperState, targetDef }) {
31
+ const numLayers = paperState?.num_layers ?? metrics?.num_layers ?? null;
32
+ const foldCount = metrics?.fold_count ?? paperState?.fold_count ?? null;
33
+
34
  return (
35
  <div className="info-badges">
36
  <div className="info-row">
37
+ <span className="info-key">fold_count</span>
38
+ <NumVal value={foldCount} />
39
  </div>
40
  <div className="info-row">
41
+ <span className="info-key">num_layers</span>
42
+ <NumVal value={numLayers} />
43
  </div>
44
  <div className="info-row">
45
+ <span className="info-key">is_valid</span>
46
+ <BoolVal value={metrics ? metrics.is_valid : null} />
47
  </div>
48
  <div className="info-row">
49
+ <span className="info-key">strain_exceeded</span>
50
+ <BoolVal value={metrics ? metrics.strain_exceeded : null} />
51
  </div>
52
  <div className="info-row">
53
+ <span className="info-key">is_deployable</span>
54
+ <BoolVal value={metrics ? metrics.is_deployable : null} />
 
 
 
55
  </div>
56
  {targetDef && (
57
  <>
 
60
  <span className="info-val">LVL {targetDef.level}</span>
61
  </div>
62
  <div className="info-row">
63
+ <span className="info-key">material</span>
64
+ <TextVal value={targetDef.material} dim={true} />
65
+ </div>
66
+ <div className="info-row">
67
+ <span className="info-key">task</span>
68
  <span className="info-val" style={{ fontSize: '10px', textAlign: 'right', maxWidth: '140px', wordBreak: 'break-word' }}>
69
+ {(targetDef.name || '').replace(/_/g, ' ').toUpperCase()}
70
  </span>
71
  </div>
72
  </>
src/components/RewardPanel.js CHANGED
@@ -1,50 +1,89 @@
1
- const REWARD_FIELDS = [
2
- { key: 'kawasaki', label: 'kawasaki', color: 'var(--validity)' },
3
- { key: 'maekawa', label: 'maekawa', color: 'var(--validity)' },
4
- { key: 'blb', label: 'blb', color: 'var(--validity)' },
5
- { key: 'progress', label: 'progress', color: 'var(--progress)' },
6
- { key: 'economy', label: 'economy', color: 'var(--economy)' },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ];
8
 
9
- const TOTAL_FIELD = { key: 'total', label: 'total', color: 'var(--text-primary)' };
10
-
11
- function RewardRow({ label, color, value }) {
12
- const isDash = value === null || value === undefined;
13
- const pct = isDash ? 0 : Math.min(Math.max(value, 0), 1) * 100;
14
-
15
  return (
16
  <div className="reward-row">
17
  <span className="reward-label">{label}</span>
18
  <div className="reward-track">
19
  <div
20
  className="reward-bar"
21
- style={{ width: `${pct}%`, background: color }}
22
  />
23
  </div>
24
  <span className={`reward-value${isDash ? ' dim' : ''}`}>
25
- {isDash ? '—' : value.toFixed(2)}
26
  </span>
27
  </div>
28
  );
29
  }
30
 
31
- export default function RewardPanel({ reward }) {
32
  return (
33
  <div className="reward-panel">
34
- {REWARD_FIELDS.map(({ key, label, color }) => (
35
- <RewardRow
36
- key={key}
37
- label={label}
38
- color={color}
39
- value={reward ? reward[key] : null}
40
- />
41
- ))}
42
- <div className="reward-divider" />
43
- <RewardRow
44
- label={TOTAL_FIELD.label}
45
- color={TOTAL_FIELD.color}
46
- value={reward ? reward[TOTAL_FIELD.key] : null}
47
- />
 
 
48
  </div>
49
  );
50
  }
 
1
+ const METRIC_FIELDS = [
2
+ {
3
+ key: 'compactness',
4
+ label: 'compactness',
5
+ color: 'var(--progress)',
6
+ normalize: (v) => Math.min(Math.max(v || 0, 0), 1),
7
+ format: (v) => (v != null ? v.toFixed(3) : '—'),
8
+ },
9
+ {
10
+ key: 'max_strain',
11
+ label: 'max strain',
12
+ color: 'var(--validity)',
13
+ // Show as inverted bar: low strain = small bar (good)
14
+ normalize: (v) => Math.min((v || 0) / 0.2, 1),
15
+ format: (v) => (v != null ? v.toFixed(4) : '—'),
16
+ inverted: true,
17
+ },
18
+ {
19
+ key: 'kawasaki_violations',
20
+ label: 'kawasaki',
21
+ color: 'var(--validity)',
22
+ normalize: (v) => Math.min((v || 0) / 5, 1),
23
+ format: (v) => (v != null ? String(v) : '—'),
24
+ inverted: true,
25
+ },
26
+ {
27
+ key: 'maekawa_violations',
28
+ label: 'maekawa',
29
+ color: 'var(--validity)',
30
+ normalize: (v) => Math.min((v || 0) / 5, 1),
31
+ format: (v) => (v != null ? String(v) : '—'),
32
+ inverted: true,
33
+ },
34
+ {
35
+ key: 'fits_target_box',
36
+ label: 'fits box',
37
+ color: 'var(--progress)',
38
+ normalize: (v) => (v ? 1 : 0),
39
+ format: (v) => (v == null ? '—' : v ? 'YES' : 'NO'),
40
+ },
41
+ {
42
+ key: 'is_deployable',
43
+ label: 'deployable',
44
+ color: 'var(--progress)',
45
+ normalize: (v) => (v ? 1 : 0),
46
+ format: (v) => (v == null ? '—' : v ? 'YES' : 'NO'),
47
+ },
48
  ];
49
 
50
+ function RewardRow({ label, color, pct, formattedValue, isDash, inverted }) {
51
+ const barColor = inverted && pct > 0 ? 'var(--validity)' : color;
 
 
 
 
52
  return (
53
  <div className="reward-row">
54
  <span className="reward-label">{label}</span>
55
  <div className="reward-track">
56
  <div
57
  className="reward-bar"
58
+ style={{ width: `${isDash ? 0 : pct}%`, background: barColor }}
59
  />
60
  </div>
61
  <span className={`reward-value${isDash ? ' dim' : ''}`}>
62
+ {formattedValue}
63
  </span>
64
  </div>
65
  );
66
  }
67
 
68
+ export default function RewardPanel({ metrics }) {
69
  return (
70
  <div className="reward-panel">
71
+ {METRIC_FIELDS.map(({ key, label, color, normalize, format, inverted }) => {
72
+ const raw = metrics ? metrics[key] : undefined;
73
+ const isDash = raw === null || raw === undefined;
74
+ const pct = isDash ? 0 : normalize(raw) * 100;
75
+ return (
76
+ <RewardRow
77
+ key={key}
78
+ label={label}
79
+ color={color}
80
+ pct={pct}
81
+ formattedValue={isDash ? '—' : format(raw)}
82
+ isDash={isDash}
83
+ inverted={!!inverted}
84
+ />
85
+ );
86
+ })}
87
  </div>
88
  );
89
  }
src/components/StepFeed.js CHANGED
@@ -1,14 +1,32 @@
1
  import { useEffect, useRef } from 'react';
2
 
3
- function rewardDelta(step, prevStep) {
4
- if (!step || !step.reward) return null;
5
- const curr = step.reward.total;
6
- if (prevStep && prevStep.reward) {
7
- return curr - prevStep.reward.total;
 
8
  }
9
  return curr;
10
  }
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  export default function StepFeed({ steps, currentStep }) {
13
  const feedRef = useRef(null);
14
  const activeRef = useRef(null);
@@ -34,9 +52,10 @@ export default function StepFeed({ steps, currentStep }) {
34
  {steps.map((step, idx) => {
35
  const stepNum = idx + 1;
36
  const isActive = currentStep === stepNum;
37
- const delta = rewardDelta(step, idx > 0 ? steps[idx - 1] : null);
38
- const asgn = step.fold ? step.fold.assignment : null;
39
- const instruction = step.fold ? step.fold.instruction : (step.prompt || '');
 
40
 
41
  return (
42
  <div
@@ -46,21 +65,23 @@ export default function StepFeed({ steps, currentStep }) {
46
  >
47
  <div className="step-entry-top">
48
  <span className="step-num">#{stepNum}</span>
49
- <span className="step-instruction">{instruction}</span>
50
  {asgn && (
51
  <span className={`assign-badge ${asgn}`}>{asgn}</span>
52
  )}
53
  </div>
54
  {delta !== null && (
55
  <div className="step-reward-delta">
56
- {'\u0394'} total:{' '}
57
  <span className={delta >= 0 ? 'delta-positive' : 'delta-negative'}>
58
  {delta >= 0 ? '+' : ''}{delta.toFixed(3)}
59
  </span>
60
- {step.reward && (
61
  <span style={{ color: 'var(--text-dim)' }}>
62
- {' '}| progress: {step.reward.progress.toFixed(2)}
63
- {' '}| economy: {step.reward.economy.toFixed(2)}
 
 
64
  </span>
65
  )}
66
  </div>
 
1
  import { useEffect, useRef } from 'react';
2
 
3
+ function compactnessDelta(step, prevStep) {
4
+ if (!step || !step.metrics) return null;
5
+ const curr = step.metrics.compactness;
6
+ if (curr == null) return null;
7
+ if (prevStep && prevStep.metrics && prevStep.metrics.compactness != null) {
8
+ return curr - prevStep.metrics.compactness;
9
  }
10
  return curr;
11
  }
12
 
13
+ function foldAssignment(fold) {
14
+ if (!fold) return null;
15
+ const t = fold.type || '';
16
+ if (t === 'valley') return 'V';
17
+ if (t === 'mountain') return 'M';
18
+ if (t === 'pleat') return 'P';
19
+ if (t === 'crimp') return 'C';
20
+ return t.charAt(0).toUpperCase() || null;
21
+ }
22
+
23
+ function foldLabel(fold) {
24
+ if (!fold) return '';
25
+ const type = fold.type || 'fold';
26
+ const angle = fold.angle != null ? ` ${fold.angle}°` : '';
27
+ return `${type.toUpperCase()} FOLD${angle}`;
28
+ }
29
+
30
  export default function StepFeed({ steps, currentStep }) {
31
  const feedRef = useRef(null);
32
  const activeRef = useRef(null);
 
52
  {steps.map((step, idx) => {
53
  const stepNum = idx + 1;
54
  const isActive = currentStep === stepNum;
55
+ const delta = compactnessDelta(step, idx > 0 ? steps[idx - 1] : null);
56
+ const asgn = foldAssignment(step.fold);
57
+ const label = foldLabel(step.fold);
58
+ const m = step.metrics || {};
59
 
60
  return (
61
  <div
 
65
  >
66
  <div className="step-entry-top">
67
  <span className="step-num">#{stepNum}</span>
68
+ <span className="step-instruction">{label}</span>
69
  {asgn && (
70
  <span className={`assign-badge ${asgn}`}>{asgn}</span>
71
  )}
72
  </div>
73
  {delta !== null && (
74
  <div className="step-reward-delta">
75
+ {'\u0394'} compact:{' '}
76
  <span className={delta >= 0 ? 'delta-positive' : 'delta-negative'}>
77
  {delta >= 0 ? '+' : ''}{delta.toFixed(3)}
78
  </span>
79
+ {m.max_strain != null && (
80
  <span style={{ color: 'var(--text-dim)' }}>
81
+ {' '}| strain: {m.max_strain.toFixed(4)}
82
+ {m.is_valid != null && (
83
+ <span> | {m.is_valid ? '✓' : '✗'}</span>
84
+ )}
85
  </span>
86
  )}
87
  </div>
training/__init__.py ADDED
File without changes
training/demo.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ training/demo.py — Run 8 zero-shot rollouts and stream them to the grid viewer.
3
+
4
+ Usage:
5
+ cd /path/to/optigami
6
+ python -m training.demo
7
+
8
+ Then open: http://localhost:9001/viewer/training.html
9
+
10
+ Each of the 8 "strategies" is a heuristic that mimics what a pretrained LLM might
11
+ produce for different tasks — varying from near-optimal to poor. This exercises
12
+ the full broadcast → grid viewer pipeline without requiring an LLM API key.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import time
18
+ import uuid
19
+ from typing import Callable
20
+
21
+ import uvicorn
22
+
23
+ from server.app import app, broadcast
24
+ from training.runner import run_batch
25
+
26
+
27
+ # ── 8 zero-shot heuristic strategies ──────────────────────────────────────────
28
+ # Each is a callable: paper_state (dict) → fold_dict
29
+ # These represent the range of strategies a pretrained LLM might generate.
30
+
31
+ def strategy_perfect_half(paper_state: dict) -> dict:
32
+ """Valley fold exactly at horizontal midline — optimal for half_fold."""
33
+ return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
34
+
35
+
36
+ def strategy_slight_offset(paper_state: dict) -> dict:
37
+ """Valley fold slightly off-center — almost optimal."""
38
+ return {"type": "valley", "line": {"start": [0.0, 0.48], "end": [1.0, 0.48]}, "angle": 180.0}
39
+
40
+
41
+ def strategy_thirds(paper_state: dict) -> dict:
42
+ """Letter fold at one-third — wrong for half_fold, generates interesting geometry."""
43
+ fold_count = paper_state.get("fold_count", 0)
44
+ positions = [0.333, 0.667]
45
+ if fold_count >= len(positions):
46
+ return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
47
+ return {
48
+ "type": "valley" if fold_count == 0 else "mountain",
49
+ "line": {"start": [0.0, positions[fold_count]], "end": [1.0, positions[fold_count]]},
50
+ "angle": 180.0,
51
+ }
52
+
53
+
54
+ def strategy_vertical(paper_state: dict) -> dict:
55
+ """Vertical fold — gets compactness but in wrong dimension for target_box."""
56
+ return {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}
57
+
58
+
59
+ def strategy_mountain(paper_state: dict) -> dict:
60
+ """Mountain fold at midline — same geometry, different assignment."""
61
+ return {"type": "mountain", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
62
+
63
+
64
+ def strategy_accordion(paper_state: dict) -> dict:
65
+ """Accordion 3-fold — overfolds, achieves high compactness but more folds."""
66
+ fold_count = paper_state.get("fold_count", 0)
67
+ positions = [0.25, 0.5, 0.75]
68
+ assignments = ["valley", "mountain", "valley"]
69
+ if fold_count >= len(positions):
70
+ return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
71
+ return {
72
+ "type": assignments[fold_count],
73
+ "line": {"start": [0.0, positions[fold_count]], "end": [1.0, positions[fold_count]]},
74
+ "angle": 180.0,
75
+ }
76
+
77
+
78
+ def strategy_diagonal(paper_state: dict) -> dict:
79
+ """Diagonal fold — achieves compactness but irregular bounding box."""
80
+ return {"type": "valley", "line": {"start": [0.0, 0.0], "end": [1.0, 1.0]}, "angle": 180.0}
81
+
82
+
83
+ def strategy_quarter(paper_state: dict) -> dict:
84
+ """Two perpendicular folds — 4x compactness for quarter_fold task."""
85
+ fold_count = paper_state.get("fold_count", 0)
86
+ if fold_count == 0:
87
+ return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
88
+ if fold_count == 1:
89
+ return {"type": "valley", "line": {"start": [0.5, 0.0], "end": [0.5, 1.0]}, "angle": 180.0}
90
+ return {"type": "stop", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 0.0}
91
+
92
+
93
+ STRATEGIES: list[tuple[str, Callable]] = [
94
+ ("perfect_half", strategy_perfect_half),
95
+ ("slight_offset", strategy_slight_offset),
96
+ ("thirds_fold", strategy_thirds),
97
+ ("vertical_fold", strategy_vertical),
98
+ ("mountain_fold", strategy_mountain),
99
+ ("accordion_3", strategy_accordion),
100
+ ("diagonal", strategy_diagonal),
101
+ ("quarter_fold", strategy_quarter),
102
+ ]
103
+
104
+
105
+ # ── Demo runner ────────────────────────────────────────────────────────────────
106
+
107
+ async def run_demo(task_name: str = "half_fold", delay_s: float = 0.5) -> None:
108
+ """Wait for server to be ready, then fire 8 episodes."""
109
+ # Give uvicorn time to bind and call startup hook (sets broadcast._loop)
110
+ await asyncio.sleep(1.5)
111
+
112
+ batch_id = 1
113
+ names, fns = zip(*STRATEGIES)
114
+ ep_ids = [f"ep_{name}" for name in names]
115
+
116
+ print(f"\n[demo] Starting batch {batch_id} — task: {task_name}")
117
+ print(f"[demo] Open http://localhost:9001/viewer/training.html\n")
118
+
119
+ # Signal grid to clear and show G=8
120
+ await broadcast.start_batch(batch_id, len(fns))
121
+
122
+ await asyncio.sleep(delay_s)
123
+
124
+ # Run all 8 episodes in the thread pool; broadcast_fn fires into this loop
125
+ results = await asyncio.gather(*[
126
+ asyncio.to_thread(
127
+ _run_one,
128
+ fn,
129
+ task_name,
130
+ ep_id,
131
+ broadcast.publish,
132
+ )
133
+ for fn, ep_id in zip(fns, ep_ids)
134
+ ])
135
+
136
+ scores = [r["score"] for r in results]
137
+ best_idx = max(range(len(scores)), key=lambda i: scores[i])
138
+
139
+ await broadcast.finish_batch(batch_id, scores, best_episode_id=ep_ids[best_idx])
140
+
141
+ print("\n[demo] Results:")
142
+ for name, result in zip(names, results):
143
+ print(f" {name:20s} score={result['score']:+.2f} status={result['status']}")
144
+ print(f"\n[demo] Best: {names[best_idx]} (score={scores[best_idx]:+.2f})")
145
+ print("\n[demo] Grid viewer running. Press Ctrl+C to stop.\n")
146
+
147
+
148
+ def _run_one(
149
+ strategy_fn: Callable,
150
+ task_name: str,
151
+ ep_id: str,
152
+ broadcast_fn: Callable,
153
+ ) -> dict:
154
+ """Thin wrapper: adds a small sleep between steps so the viewer can animate."""
155
+ from server.models import OrigamiAction
156
+ from server.origami_environment import OrigamiEnvironment
157
+
158
+ env = OrigamiEnvironment()
159
+ obs = env.reset(task_name=task_name)
160
+
161
+ broadcast_fn(ep_id, {
162
+ "type": "episode_update",
163
+ "episode_id": ep_id,
164
+ "task_name": task_name,
165
+ "step": 0,
166
+ "observation": _obs_dict(obs),
167
+ })
168
+
169
+ max_steps = env._task.get("max_folds", 10) if env._task else 10
170
+ status = "done"
171
+
172
+ for step_idx in range(max_steps):
173
+ if obs.done:
174
+ break
175
+
176
+ time.sleep(0.3) # pace so the viewer can animate each step
177
+
178
+ fold_dict = strategy_fn(obs.paper_state)
179
+
180
+ if fold_dict.get("type") == "stop":
181
+ break
182
+
183
+ action = OrigamiAction(
184
+ fold_type=fold_dict["type"],
185
+ fold_line=fold_dict["line"],
186
+ fold_angle=float(fold_dict.get("angle", 180.0)),
187
+ )
188
+ obs = env.step(action)
189
+
190
+ broadcast_fn(ep_id, {
191
+ "type": "episode_update",
192
+ "episode_id": ep_id,
193
+ "task_name": task_name,
194
+ "step": step_idx + 1,
195
+ "observation": _obs_dict(obs),
196
+ })
197
+
198
+ if obs.done:
199
+ break
200
+ else:
201
+ status = "timeout"
202
+
203
+ score = obs.reward if obs.reward is not None else env._total_reward or 0.0
204
+
205
+ broadcast_fn(ep_id, {
206
+ "type": "episode_done",
207
+ "episode_id": ep_id,
208
+ "status": status,
209
+ "score": float(score),
210
+ "final_metrics": obs.metrics,
211
+ })
212
+
213
+ return {
214
+ "episode_id": ep_id,
215
+ "score": float(score),
216
+ "final_metrics": obs.metrics,
217
+ "status": status,
218
+ }
219
+
220
+
221
+ def _obs_dict(obs) -> dict:
222
+ try:
223
+ return obs.model_dump()
224
+ except AttributeError:
225
+ return {
226
+ "paper_state": getattr(obs, "paper_state", {}),
227
+ "metrics": getattr(obs, "metrics", {}),
228
+ "fold_history": getattr(obs, "fold_history", []),
229
+ "done": getattr(obs, "done", False),
230
+ "reward": getattr(obs, "reward", None),
231
+ }
232
+
233
+
234
+ # ── Entry point ────────────────────────────────────────────────────────────────
235
+
236
+ async def _main() -> None:
237
+ config = uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
238
+ server = uvicorn.Server(config)
239
+
240
+ # Run demo concurrently with the uvicorn server
241
+ await asyncio.gather(
242
+ server.serve(),
243
+ run_demo(task_name="half_fold"),
244
+ )
245
+
246
+
247
+ if __name__ == "__main__":
248
+ try:
249
+ asyncio.run(_main())
250
+ except KeyboardInterrupt:
251
+ print("\n[demo] Stopped.")
training/demo_llm.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ training/demo_llm.py — 8 rollouts using Claude as the zero-shot fold strategist.
3
+
4
+ Usage:
5
+ cd /path/to/optigami
6
+ ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
7
+
8
+ Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
9
+ Claude sees the current paper_state metrics and decides the next fold.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import os
16
+ import re
17
+ import time
18
+ from typing import Any
19
+
20
+ import anthropic
21
+ import uvicorn
22
+
23
+ from server.app import app, broadcast
24
+ from server.models import OrigamiAction
25
+ from server.origami_environment import OrigamiEnvironment
26
+ from server.tasks import get_task_by_name
27
+
28
+
29
+ TASK_NAME = "half_fold"
30
+ NUM_EPISODES = 8
31
+ MODEL = "claude-haiku-4-5-20251001"
32
+
33
+
34
+ # ── LLM strategy factory ───────────────────────────────────────────────────────
35
+
36
+ def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
37
+ """Return a strategy_fn for one episode. Each episode gets its own call history."""
38
+ history: list[dict[str, Any]] = []
39
+
40
+ def strategy(paper_state: dict) -> dict:
41
+ fold_count = paper_state.get("fold_count", 0)
42
+ compactness = paper_state.get("compactness", 0)
43
+ bb = paper_state.get("bounding_box", [1, 1, 0])
44
+ target_box = task.get("target_box", [1, 0.5, 0.02])
45
+ max_folds = task.get("max_folds", 3)
46
+
47
+ user_msg = f"""You are folding a {task['width']}x{task['height']} sheet of {task['material']}.
48
+ Task: {task['description']}
49
+ Target box to fit inside: {target_box}
50
+ Max folds allowed: {max_folds}
51
+
52
+ Current state (fold {fold_count}/{max_folds}):
53
+ compactness: {compactness:.3f} (1.0 = fully packed, 0.0 = flat)
54
+ bounding_box: [{bb[0]:.3f}, {bb[1]:.3f}, {bb[2]:.4f}]
55
+ fits_target_box: {paper_state.get('fits_target_box', False)}
56
+
57
+ Choose the next fold. Respond with ONLY valid JSON, no other text:
58
+ {{
59
+ "type": "valley" or "mountain" or "stop",
60
+ "line": {{"start": [x, y], "end": [x, y]}},
61
+ "angle": 180
62
+ }}
63
+
64
+ Coordinates are normalized 0-1. Use "stop" if done."""
65
+
66
+ history.append({"role": "user", "content": user_msg})
67
+
68
+ response = client.messages.create(
69
+ model=MODEL,
70
+ max_tokens=120,
71
+ messages=history,
72
+ )
73
+ reply = response.content[0].text.strip()
74
+ history.append({"role": "assistant", "content": reply})
75
+
76
+ # Extract JSON — handle markdown code blocks
77
+ match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
78
+ if not match:
79
+ return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}
80
+
81
+ fold_dict = json.loads(match.group())
82
+ # Normalize: ensure required keys
83
+ fold_dict.setdefault("type", "valley")
84
+ fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
85
+ fold_dict.setdefault("angle", 180.0)
86
+ return fold_dict
87
+
88
+ return strategy
89
+
90
+
91
+ # ── Episode runner ─────────────────────────────────────────────────────────────
92
+
93
+ def run_episode_llm(
94
+ strategy_fn,
95
+ task_name: str,
96
+ ep_id: str,
97
+ broadcast_fn,
98
+ ) -> dict:
99
+ env = OrigamiEnvironment()
100
+ obs = env.reset(task_name=task_name)
101
+ task = env._task or {}
102
+
103
+ broadcast_fn(ep_id, {
104
+ "type": "episode_update",
105
+ "episode_id": ep_id,
106
+ "task_name": task_name,
107
+ "step": 0,
108
+ "observation": _obs_dict(obs),
109
+ })
110
+
111
+ max_steps = task.get("max_folds", 5)
112
+ status = "done"
113
+
114
+ for step_idx in range(max_steps):
115
+ if obs.done:
116
+ break
117
+
118
+ # Build a flat paper_state dict for the LLM (add metrics inline)
119
+ ps = dict(obs.paper_state)
120
+ ps.update(obs.metrics) # compactness, fits_target_box, etc.
121
+ ps["fold_count"] = step_idx
122
+
123
+ try:
124
+ fold_dict = strategy_fn(ps)
125
+ except Exception as exc:
126
+ broadcast_fn(ep_id, {
127
+ "type": "episode_done", "episode_id": ep_id,
128
+ "status": "error", "score": 0.0,
129
+ "final_metrics": obs.metrics, "error": str(exc),
130
+ })
131
+ return {"episode_id": ep_id, "score": 0.0, "status": "error"}
132
+
133
+ if fold_dict.get("type") == "stop":
134
+ break
135
+
136
+ time.sleep(0.4) # pace for viewer animation
137
+
138
+ action = OrigamiAction(
139
+ fold_type=fold_dict["type"],
140
+ fold_line=fold_dict["line"],
141
+ fold_angle=float(fold_dict.get("angle", 180.0)),
142
+ )
143
+ obs = env.step(action)
144
+
145
+ broadcast_fn(ep_id, {
146
+ "type": "episode_update",
147
+ "episode_id": ep_id,
148
+ "task_name": task_name,
149
+ "step": step_idx + 1,
150
+ "observation": _obs_dict(obs),
151
+ })
152
+
153
+ if obs.done:
154
+ break
155
+ else:
156
+ status = "timeout"
157
+
158
+ score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
159
+ broadcast_fn(ep_id, {
160
+ "type": "episode_done",
161
+ "episode_id": ep_id,
162
+ "status": status,
163
+ "score": float(score),
164
+ "final_metrics": obs.metrics,
165
+ })
166
+
167
+ return {"episode_id": ep_id, "score": float(score), "status": status}
168
+
169
+
170
+ def _obs_dict(obs) -> dict:
171
+ try:
172
+ return obs.model_dump()
173
+ except AttributeError:
174
+ return {
175
+ "paper_state": getattr(obs, "paper_state", {}),
176
+ "metrics": getattr(obs, "metrics", {}),
177
+ "fold_history": getattr(obs, "fold_history", []),
178
+ "done": getattr(obs, "done", False),
179
+ "reward": getattr(obs, "reward", None),
180
+ }
181
+
182
+
183
+ # ── Main ──────────────────────────────────────────────────────────────────────
184
+
185
+ async def run_demo() -> None:
186
+ api_key = os.environ.get("ANTHROPIC_API_KEY")
187
+ if not api_key:
188
+ raise RuntimeError("Set ANTHROPIC_API_KEY environment variable")
189
+
190
+ client = anthropic.Anthropic(api_key=api_key)
191
+ task = get_task_by_name(TASK_NAME)
192
+
193
+ await asyncio.sleep(1.5) # wait for server startup
194
+
195
+ print(f"\n[llm-demo] Model: {MODEL}")
196
+ print(f"[llm-demo] Task: {TASK_NAME} — {task['description']}")
197
+ print(f"[llm-demo] Open: http://localhost:9001/viewer/training.html\n")
198
+
199
+ await broadcast.start_batch(1, NUM_EPISODES)
200
+
201
+ ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
202
+ strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
203
+
204
+ # Run all episodes concurrently (each makes its own Claude API calls)
205
+ results = await asyncio.gather(*[
206
+ asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
207
+ for fn, ep_id in zip(strategies, ep_ids)
208
+ ])
209
+
210
+ scores = [r["score"] for r in results]
211
+ best_idx = max(range(len(scores)), key=lambda i: scores[i])
212
+
213
+ await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
214
+
215
+ print("\n[llm-demo] Results:")
216
+ for i, result in enumerate(results):
217
+ print(f" ep_{i:02d} score={result['score']:+.2f} status={result['status']}")
218
+ print(f"\n[llm-demo] Best: ep_{best_idx:02d} (score={scores[best_idx]:+.2f})")
219
+ print("\n[llm-demo] Press Ctrl+C to stop.\n")
220
+
221
+
222
+ async def _main() -> None:
223
+ config = uvicorn.Config(app, host="0.0.0.0", port=9001, log_level="warning")
224
+ server = uvicorn.Server(config)
225
+ await asyncio.gather(server.serve(), run_demo())
226
+
227
+
228
+ if __name__ == "__main__":
229
+ try:
230
+ asyncio.run(_main())
231
+ except KeyboardInterrupt:
232
+ print("\n[llm-demo] Stopped.")
training/runner.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TrainingRunner — parallel episode executor for GRPO training.
3
+
4
+ Each episode runs in a ThreadPoolExecutor thread.
5
+ After every env.step(), observations are pushed to the broadcast server (fire-and-forget).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import uuid
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from typing import Any, Callable, Optional
12
+
13
+ from server.models import OrigamiAction
14
+ from server.origami_environment import OrigamiEnvironment
15
+
16
+
17
+ BroadcastFn = Callable[[str, dict], None]
18
+
19
+
20
+ def run_episode(
21
+ strategy_fn: Callable[[dict], dict],
22
+ task_name: str,
23
+ ep_id: Optional[str] = None,
24
+ broadcast_fn: Optional[BroadcastFn] = None,
25
+ max_steps: Optional[int] = None,
26
+ ) -> dict:
27
+ """Run a single origami episode with a given strategy function.
28
+
29
+ Args:
30
+ strategy_fn: Callable that receives paper_state dict and returns a fold dict:
31
+ {"type": "valley"|"mountain"|"pleat"|"crimp"|"stop",
32
+ "line": {"start": [x, y], "end": [x, y]},
33
+ "angle": 180.0}
34
+ task_name: Name of the task (from server/tasks.py)
35
+ ep_id: Episode identifier for broadcast; auto-generated if None
36
+ broadcast_fn: Optional callback(ep_id, data) for live streaming
37
+ max_steps: Override task's max_folds if provided
38
+
39
+ Returns:
40
+ dict with keys: episode_id, score, final_metrics, fold_history, status
41
+ """
42
+ ep_id = ep_id or str(uuid.uuid4())[:8]
43
+ env = OrigamiEnvironment()
44
+
45
+ obs = env.reset(task_name=task_name)
46
+
47
+ if broadcast_fn:
48
+ broadcast_fn(ep_id, {
49
+ "type": "episode_update",
50
+ "episode_id": ep_id,
51
+ "task_name": task_name,
52
+ "step": 0,
53
+ "observation": _obs_to_dict(obs),
54
+ })
55
+
56
+ step_limit = max_steps or env._task.get("max_folds", 20) if env._task else 20
57
+ status = "done"
58
+
59
+ for step_idx in range(step_limit):
60
+ if obs.done:
61
+ break
62
+
63
+ # Strategy generates a fold dict
64
+ try:
65
+ fold_dict = strategy_fn(obs.paper_state)
66
+ except Exception as exc:
67
+ status = "error"
68
+ if broadcast_fn:
69
+ broadcast_fn(ep_id, {
70
+ "type": "episode_done",
71
+ "episode_id": ep_id,
72
+ "status": "error",
73
+ "score": obs.reward or 0.0,
74
+ "final_metrics": obs.metrics,
75
+ "error": str(exc),
76
+ })
77
+ break
78
+
79
+ fold_type = fold_dict.get("type", "valley")
80
+ fold_line = fold_dict.get("line", {"start": [0, 0.5], "end": [1, 0.5]})
81
+ fold_angle = float(fold_dict.get("angle", 180.0))
82
+
83
+ action = OrigamiAction(
84
+ fold_type=fold_type,
85
+ fold_line=fold_line,
86
+ fold_angle=fold_angle,
87
+ )
88
+ obs = env.step(action)
89
+
90
+ if broadcast_fn:
91
+ broadcast_fn(ep_id, {
92
+ "type": "episode_update",
93
+ "episode_id": ep_id,
94
+ "task_name": task_name,
95
+ "step": step_idx + 1,
96
+ "observation": _obs_to_dict(obs),
97
+ })
98
+
99
+ if obs.done:
100
+ break
101
+ else:
102
+ status = "timeout"
103
+
104
+ score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
105
+
106
+ if broadcast_fn:
107
+ broadcast_fn(ep_id, {
108
+ "type": "episode_done",
109
+ "episode_id": ep_id,
110
+ "status": status,
111
+ "score": float(score),
112
+ "final_metrics": obs.metrics,
113
+ })
114
+
115
+ return {
116
+ "episode_id": ep_id,
117
+ "score": float(score),
118
+ "final_metrics": obs.metrics,
119
+ "fold_history": obs.fold_history,
120
+ "status": status,
121
+ }
122
+
123
+
124
+ def run_batch(
125
+ strategy_fns: list[Callable[[dict], dict]],
126
+ task_name: str,
127
+ broadcast_fn: Optional[BroadcastFn] = None,
128
+ batch_id: Optional[int] = None,
129
+ max_workers: int = 8,
130
+ ) -> list[dict]:
131
+ """Run G episodes in parallel with a ThreadPoolExecutor.
132
+
133
+ Args:
134
+ strategy_fns: List of G strategy callables (one per completion)
135
+ task_name: Task to use for all episodes
136
+ broadcast_fn: Optional broadcast callback, called after each step
137
+ batch_id: Batch identifier for broadcast
138
+ max_workers: Max parallel threads (bounded by G)
139
+
140
+ Returns:
141
+ List of episode result dicts, in same order as strategy_fns
142
+ """
143
+ n = len(strategy_fns)
144
+ ep_ids = [f"ep_{(batch_id or 0):04d}_{i:02d}" for i in range(n)]
145
+ workers = min(max_workers, n)
146
+
147
+ results: list[dict] = [{}] * n
148
+
149
+ with ThreadPoolExecutor(max_workers=workers) as pool:
150
+ futures = {
151
+ pool.submit(
152
+ run_episode,
153
+ fn,
154
+ task_name,
155
+ ep_ids[i],
156
+ broadcast_fn,
157
+ ): i
158
+ for i, fn in enumerate(strategy_fns)
159
+ }
160
+
161
+ for future in as_completed(futures):
162
+ idx = futures[future]
163
+ try:
164
+ results[idx] = future.result()
165
+ except Exception as exc:
166
+ results[idx] = {
167
+ "episode_id": ep_ids[idx],
168
+ "score": 0.0,
169
+ "final_metrics": {},
170
+ "fold_history": [],
171
+ "status": "error",
172
+ "error": str(exc),
173
+ }
174
+
175
+ return results
176
+
177
+
178
+ def _obs_to_dict(obs) -> dict:
179
+ """Convert OrigamiObservation to a JSON-serializable dict."""
180
+ try:
181
+ return obs.model_dump()
182
+ except AttributeError:
183
+ return {
184
+ "task": obs.task if hasattr(obs, "task") else {},
185
+ "paper_state": obs.paper_state if hasattr(obs, "paper_state") else {},
186
+ "metrics": obs.metrics if hasattr(obs, "metrics") else {},
187
+ "fold_history": obs.fold_history if hasattr(obs, "fold_history") else [],
188
+ "done": obs.done if hasattr(obs, "done") else False,
189
+ "reward": obs.reward if hasattr(obs, "reward") else None,
190
+ "error": obs.error if hasattr(obs, "error") else None,
191
+ }
viewer/training.html ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>OPTIGAMI — TRAINING GRID VIEWER</title>
7
+ <style>
8
+ :root {
9
+ --bg: #0d0d1a;
10
+ --panel: #13131f;
11
+ --border: #1e1e2e;
12
+ --text: #e2e8f0;
13
+ --dim: #4a5568;
14
+ --cyan: #38bdf8;
15
+ --amber: #f59e0b;
16
+ --green: #22c55e;
17
+ --red: #ef4444;
18
+ --font: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', monospace;
19
+ }
20
+
21
+ * { box-sizing: border-box; margin: 0; padding: 0; }
22
+
23
+ body {
24
+ background: var(--bg);
25
+ color: var(--text);
26
+ font-family: var(--font);
27
+ font-size: 11px;
28
+ min-height: 100vh;
29
+ display: flex;
30
+ flex-direction: column;
31
+ }
32
+
33
+ /* Header */
34
+ header {
35
+ display: flex;
36
+ align-items: center;
37
+ gap: 16px;
38
+ padding: 10px 16px;
39
+ background: var(--panel);
40
+ border-bottom: 1px solid var(--border);
41
+ flex-shrink: 0;
42
+ }
43
+
44
+ .logo {
45
+ font-size: 14px;
46
+ letter-spacing: 2px;
47
+ font-weight: 700;
48
+ }
49
+
50
+ .logo .accent { color: var(--cyan); }
51
+
52
+ .header-sep { width: 1px; height: 20px; background: var(--border); }
53
+
54
+ .badge {
55
+ padding: 2px 8px;
56
+ border-radius: 3px;
57
+ font-size: 10px;
58
+ letter-spacing: 1px;
59
+ font-weight: 600;
60
+ }
61
+
62
+ .badge-training { background: rgba(56,189,248,0.15); color: var(--cyan); border: 1px solid rgba(56,189,248,0.3); }
63
+ .badge-idle { background: rgba(74,85,104,0.2); color: var(--dim); border: 1px solid var(--border); }
64
+ .badge-done { background: rgba(34,197,94,0.15); color: var(--green); border: 1px solid rgba(34,197,94,0.3); }
65
+
66
+ .stat { display: flex; align-items: center; gap: 6px; color: var(--dim); }
67
+ .stat span { color: var(--text); }
68
+
69
+ .spacer { flex: 1; }
70
+
71
+ .ws-dot {
72
+ width: 8px; height: 8px; border-radius: 50%;
73
+ background: var(--dim);
74
+ transition: background 0.3s;
75
+ }
76
+ .ws-dot.connected { background: var(--green); box-shadow: 0 0 6px var(--green); }
77
+ .ws-dot.error { background: var(--red); }
78
+
79
+ /* Main grid area */
80
+ main {
81
+ flex: 1;
82
+ padding: 16px;
83
+ overflow: auto;
84
+ }
85
+
86
+ .empty-state {
87
+ display: flex;
88
+ flex-direction: column;
89
+ align-items: center;
90
+ justify-content: center;
91
+ height: 300px;
92
+ gap: 12px;
93
+ color: var(--dim);
94
+ font-size: 12px;
95
+ letter-spacing: 1px;
96
+ }
97
+
98
+ .empty-state .pulse {
99
+ width: 12px; height: 12px; border-radius: 50%;
100
+ background: var(--cyan);
101
+ animation: pulse 1.5s ease-in-out infinite;
102
+ }
103
+
104
+ @keyframes pulse {
105
+ 0%, 100% { opacity: 0.2; transform: scale(0.8); }
106
+ 50% { opacity: 1; transform: scale(1.2); }
107
+ }
108
+
109
+ /* Episode Grid */
110
+ .grid {
111
+ display: grid;
112
+ grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
113
+ gap: 12px;
114
+ }
115
+
116
+ /* Episode Cell */
117
+ .ep-cell {
118
+ background: var(--panel);
119
+ border: 1px solid var(--border);
120
+ border-radius: 6px;
121
+ overflow: hidden;
122
+ cursor: pointer;
123
+ transition: border-color 0.2s, transform 0.15s, opacity 0.3s;
124
+ animation: fadeIn 0.4s ease;
125
+ position: relative;
126
+ }
127
+
128
+ @keyframes fadeIn {
129
+ from { opacity: 0; transform: translateY(8px); }
130
+ to { opacity: 1; transform: translateY(0); }
131
+ }
132
+
133
+ .ep-cell:hover { border-color: var(--cyan); transform: translateY(-2px); }
134
+ .ep-cell.running { border-color: rgba(56,189,248,0.5); }
135
+ .ep-cell.done-good { border-color: rgba(34,197,94,0.5); }
136
+ .ep-cell.done-bad { border-color: rgba(239,68,68,0.4); }
137
+
138
+ /* Fullscreen */
139
+ .ep-cell.fullscreen {
140
+ position: fixed;
141
+ inset: 0;
142
+ z-index: 100;
143
+ border-radius: 0;
144
+ cursor: default;
145
+ display: grid;
146
+ grid-template-rows: auto 1fr auto;
147
+ animation: none;
148
+ transform: none;
149
+ }
150
+
151
+ .ep-header {
152
+ display: flex;
153
+ align-items: center;
154
+ gap: 8px;
155
+ padding: 8px 10px;
156
+ border-bottom: 1px solid var(--border);
157
+ }
158
+
159
+ .ep-id { font-size: 10px; color: var(--dim); letter-spacing: 1px; }
160
+
161
+ .status-badge {
162
+ padding: 2px 6px;
163
+ border-radius: 2px;
164
+ font-size: 9px;
165
+ letter-spacing: 1px;
166
+ font-weight: 700;
167
+ }
168
+
169
+ .status-running { background: rgba(56,189,248,0.2); color: var(--cyan); }
170
+ .status-done { background: rgba(34,197,94,0.2); color: var(--green); }
171
+ .status-error { background: rgba(239,68,68,0.2); color: var(--red); }
172
+ .status-timeout { background: rgba(245,158,11,0.2); color: var(--amber); }
173
+
174
+ .ep-canvas-wrap {
175
+ background: #080810;
176
+ display: flex;
177
+ align-items: center;
178
+ justify-content: center;
179
+ height: 200px;
180
+ overflow: hidden;
181
+ }
182
+
183
+ .ep-cell.fullscreen .ep-canvas-wrap { height: 100%; }
184
+
185
+ .ep-canvas { display: block; }
186
+
187
+ .ep-footer {
188
+ display: flex;
189
+ align-items: center;
190
+ gap: 10px;
191
+ padding: 6px 10px;
192
+ border-top: 1px solid var(--border);
193
+ color: var(--dim);
194
+ font-size: 10px;
195
+ }
196
+
197
+ .ep-metric { display: flex; flex-direction: column; align-items: center; gap: 2px; }
198
+ .ep-metric .m-label { font-size: 9px; color: var(--dim); }
199
+ .ep-metric .m-val { font-size: 11px; color: var(--text); font-weight: 600; }
200
+ .ep-metric .m-val.good { color: var(--green); }
201
+ .ep-metric .m-val.bad { color: var(--red); }
202
+
203
+ .ep-sep { width: 1px; height: 20px; background: var(--border); }
204
+
205
+ /* Fullscreen extras */
206
+ .ep-detail { display: none; }
207
+ .ep-cell.fullscreen .ep-detail {
208
+ display: block;
209
+ padding: 12px;
210
+ overflow: auto;
211
+ max-height: 200px;
212
+ border-top: 1px solid var(--border);
213
+ }
214
+
215
+ .back-btn {
216
+ display: none;
217
+ position: absolute;
218
+ top: 10px;
219
+ right: 10px;
220
+ padding: 4px 10px;
221
+ background: var(--border);
222
+ color: var(--text);
223
+ border: 1px solid var(--dim);
224
+ border-radius: 3px;
225
+ cursor: pointer;
226
+ font-family: var(--font);
227
+ font-size: 10px;
228
+ letter-spacing: 1px;
229
+ }
230
+
231
+ .ep-cell.fullscreen .back-btn { display: block; }
232
+ .back-btn:hover { background: var(--cyan); color: var(--bg); }
233
+
234
+ /* Fold history in fullscreen */
235
+ .fold-history { display: flex; flex-direction: column; gap: 4px; }
236
+ .fold-entry {
237
+ display: flex;
238
+ gap: 8px;
239
+ align-items: center;
240
+ color: var(--dim);
241
+ font-size: 10px;
242
+ }
243
+ .fold-entry .step-num { color: var(--cyan); min-width: 24px; }
244
+ .fold-type-badge {
245
+ padding: 1px 5px;
246
+ border-radius: 2px;
247
+ font-size: 9px;
248
+ font-weight: 700;
249
+ }
250
+ .fold-type-valley { background: rgba(56,189,248,0.2); color: var(--cyan); }
251
+ .fold-type-mountain { background: rgba(245,158,11,0.2); color: var(--amber); }
252
+ </style>
253
+ </head>
254
+ <body>
255
+
256
+ <header>
257
+ <div class="logo">OPTI<span class="accent">GAMI</span></div>
258
+ <div class="header-sep"></div>
259
+ <div id="trainBadge" class="badge badge-idle">IDLE</div>
260
+ <div class="header-sep"></div>
261
+ <div class="stat">BATCH <span id="batchNum">&#8212;</span></div>
262
+ <div class="stat">EPISODES <span id="epCount">0</span></div>
263
+ <div class="stat">AVG REWARD <span id="avgReward">&#8212;</span></div>
264
+ <div class="spacer"></div>
265
+ <div class="stat"><div id="wsDot" class="ws-dot"></div> WS</div>
266
+ </header>
267
+
268
+ <main id="main">
269
+ <div class="empty-state" id="emptyState">
270
+ <div class="pulse"></div>
271
+ WAITING FOR TRAINING...
272
+ </div>
273
+ <div class="grid" id="grid" style="display:none"></div>
274
+ </main>
275
+
276
+ <script>
277
+ const state = {
278
+ batchId: null,
279
+ episodes: {},
280
+ fullscreenId: null,
281
+ };
282
+
283
+ const renderers = {};
284
+
285
+ function connectWS() {
286
+ const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
287
+ const url = proto + '//' + location.host + '/ws/training';
288
+ const ws = new WebSocket(url);
289
+ const dot = document.getElementById('wsDot');
290
+
291
+ ws.onopen = function() { dot.className = 'ws-dot connected'; };
292
+ ws.onclose = function() {
293
+ dot.className = 'ws-dot error';
294
+ setTimeout(connectWS, 3000);
295
+ };
296
+ ws.onerror = function() { dot.className = 'ws-dot error'; };
297
+
298
+ ws.onmessage = function(e) {
299
+ try { handleMessage(JSON.parse(e.data)); }
300
+ catch (err) { console.error('WS parse error', err); }
301
+ };
302
+ }
303
+
304
+ function handleMessage(msg) {
305
+ switch (msg.type) {
306
+ case 'registry':
307
+ state.batchId = msg.batch_id;
308
+ state.episodes = {};
309
+ Object.entries(msg.episodes || {}).forEach(function(kv) {
310
+ state.episodes[kv[0]] = kv[1];
311
+ });
312
+ rebuildGrid();
313
+ updateHeader();
314
+ break;
315
+
316
+ case 'batch_start':
317
+ state.batchId = msg.batch_id;
318
+ state.episodes = {};
319
+ setTrainingBadge('TRAINING', 'badge-training');
320
+ rebuildGrid();
321
+ updateHeader();
322
+ break;
323
+
324
+ case 'batch_done':
325
+ setTrainingBadge('BATCH DONE', 'badge-done');
326
+ document.getElementById('avgReward').textContent =
327
+ msg.avg_score != null ? msg.avg_score.toFixed(2) : '\u2014';
328
+ break;
329
+
330
+ case 'training_done':
331
+ setTrainingBadge('DONE', 'badge-done');
332
+ break;
333
+
334
+ case 'episode_update': {
335
+ const id = msg.episode_id;
336
+ if (!state.episodes[id]) {
337
+ state.episodes[id] = { status: 'running', task: msg.task_name, step: 0, metrics: {}, fold_history: [] };
338
+ addEpisodeCell(id);
339
+ }
340
+ const ep = state.episodes[id];
341
+ ep.step = msg.step;
342
+ ep.status = 'running';
343
+ if (msg.observation) {
344
+ ep.metrics = msg.observation.metrics || {};
345
+ ep.fold_history = msg.observation.fold_history || [];
346
+ ep.paper_state = msg.observation.paper_state || {};
347
+ }
348
+ updateEpisodeCell(id);
349
+ if (msg.observation && msg.observation.paper_state) {
350
+ renderStep(id, msg.observation.paper_state);
351
+ }
352
+ break;
353
+ }
354
+
355
+ case 'episode_done': {
356
+ const id = msg.episode_id;
357
+ if (!state.episodes[id]) state.episodes[id] = {};
358
+ const ep = state.episodes[id];
359
+ ep.status = msg.status || 'done';
360
+ ep.score = msg.score;
361
+ ep.final_metrics = msg.final_metrics;
362
+ updateEpisodeCell(id);
363
+ break;
364
+ }
365
+ }
366
+
367
+ document.getElementById('epCount').textContent = Object.keys(state.episodes).length;
368
+ }
369
+
370
+ function rebuildGrid() {
371
+ const grid = document.getElementById('grid');
372
+ const empty = document.getElementById('emptyState');
373
+
374
+ Object.values(renderers).forEach(function(r) { if (r.raf) cancelAnimationFrame(r.raf); });
375
+ Object.keys(renderers).forEach(function(k) { delete renderers[k]; });
376
+
377
+ grid.textContent = '';
378
+
379
+ if (Object.keys(state.episodes).length === 0) {
380
+ empty.style.display = 'flex';
381
+ grid.style.display = 'none';
382
+ return;
383
+ }
384
+
385
+ empty.style.display = 'none';
386
+ grid.style.display = 'grid';
387
+
388
+ Object.keys(state.episodes).forEach(function(id) { addEpisodeCell(id); });
389
+ }
390
+
391
+ function makeEl(tag, props) {
392
+ const el = document.createElement(tag);
393
+ if (props) {
394
+ if (props.className) el.className = props.className;
395
+ if (props.id) el.id = props.id;
396
+ if (props.style) Object.assign(el.style, props.style);
397
+ if (props.textContent !== undefined) el.textContent = props.textContent;
398
+ if (props.dataset) Object.assign(el.dataset, props.dataset);
399
+ }
400
+ return el;
401
+ }
402
+
403
+ function addEpisodeCell(id) {
404
+ const grid = document.getElementById('grid');
405
+ const empty = document.getElementById('emptyState');
406
+
407
+ empty.style.display = 'none';
408
+ grid.style.display = 'grid';
409
+
410
+ if (document.getElementById('cell-' + id)) return;
411
+
412
+ const ep = state.episodes[id];
413
+
414
+ const cell = makeEl('div', { className: 'ep-cell running', id: 'cell-' + id, dataset: { epId: id } });
415
+
416
+ // Header
417
+ const header = makeEl('div', { className: 'ep-header' });
418
+ const epIdEl = makeEl('span', { className: 'ep-id', textContent: id });
419
+ const badgeEl = makeEl('span', { className: 'status-badge status-running', id: 'badge-' + id, textContent: 'RUNNING' });
420
+ const taskEl = makeEl('span', { id: 'task-' + id, textContent: (ep.task || '').toUpperCase() });
421
+ taskEl.style.marginLeft = 'auto';
422
+ taskEl.style.color = 'var(--dim)';
423
+ taskEl.style.fontSize = '9px';
424
+ header.appendChild(epIdEl);
425
+ header.appendChild(badgeEl);
426
+ header.appendChild(taskEl);
427
+ cell.appendChild(header);
428
+
429
+ // Canvas wrap
430
+ const canvasWrap = makeEl('div', { className: 'ep-canvas-wrap' });
431
+ const canvas = makeEl('canvas', { className: 'ep-canvas', id: 'canvas-' + id });
432
+ canvas.width = 240;
433
+ canvas.height = 180;
434
+ canvasWrap.appendChild(canvas);
435
+ cell.appendChild(canvasWrap);
436
+
437
+ // Footer
438
+ const footer = makeEl('div', { className: 'ep-footer' });
439
+
440
+ function makeMetric(labelText, valId) {
441
+ const metric = makeEl('div', { className: 'ep-metric' });
442
+ const label = makeEl('span', { className: 'm-label', textContent: labelText });
443
+ const val = makeEl('span', { className: 'm-val', id: valId, textContent: '\u2014' });
444
+ metric.appendChild(label);
445
+ metric.appendChild(val);
446
+ return metric;
447
+ }
448
+
449
+ const stepMetric = makeMetric('STEP', 'step-' + id);
450
+ document.getElementById('step-' + id) || stepMetric.querySelector('[id]');
451
+ const stepValEl = stepMetric.querySelector('.m-val');
452
+ if (stepValEl) stepValEl.textContent = '0';
453
+
454
+ footer.appendChild(stepMetric);
455
+ footer.appendChild(makeEl('div', { className: 'ep-sep' }));
456
+ footer.appendChild(makeMetric('COMPACT', 'compact-' + id));
457
+ footer.appendChild(makeEl('div', { className: 'ep-sep' }));
458
+ footer.appendChild(makeMetric('REWARD', 'reward-' + id));
459
+ footer.appendChild(makeEl('div', { className: 'ep-sep' }));
460
+ footer.appendChild(makeMetric('VALID', 'valid-' + id));
461
+ cell.appendChild(footer);
462
+
463
+ // Detail panel
464
+ const detail = makeEl('div', { className: 'ep-detail', id: 'detail-' + id });
465
+ const foldsContainer = makeEl('div', { className: 'fold-history', id: 'folds-' + id });
466
+ detail.appendChild(foldsContainer);
467
+ cell.appendChild(detail);
468
+
469
+ // Back button
470
+ const backBtn = makeEl('button', { className: 'back-btn', textContent: '\u2190 GRID' });
471
+ backBtn.addEventListener('click', function(e) { exitFullscreen(e); });
472
+ cell.appendChild(backBtn);
473
+
474
+ cell.addEventListener('click', function(e) {
475
+ if (e.target === backBtn) return;
476
+ enterFullscreen(id);
477
+ });
478
+
479
+ grid.appendChild(cell);
480
+
481
+ renderers[id] = {
482
+ canvas: canvas,
483
+ ctx: canvas.getContext('2d'),
484
+ lastVerts: null,
485
+ lastFaces: null,
486
+ lastStrain: null,
487
+ raf: null,
488
+ };
489
+
490
+ drawFlatSheet(id);
491
+ updateEpisodeCell(id);
492
+ }
493
+
494
+ function updateEpisodeCell(id) {
495
+ const ep = state.episodes[id];
496
+ if (!ep) return;
497
+ const cell = document.getElementById('cell-' + id);
498
+ if (!cell) return;
499
+
500
+ cell.className = 'ep-cell';
501
+ if (ep.status === 'running') {
502
+ cell.classList.add('running');
503
+ } else if (ep.status === 'done' && (ep.score || 0) > 5) {
504
+ cell.classList.add('done-good');
505
+ } else {
506
+ cell.classList.add('done-bad');
507
+ }
508
+ if (id === state.fullscreenId) cell.classList.add('fullscreen');
509
+
510
+ const badge = document.getElementById('badge-' + id);
511
+ if (badge) {
512
+ const cls = ep.status === 'running' ? 'status-running'
513
+ : ep.status === 'done' ? 'status-done'
514
+ : ep.status === 'error' ? 'status-error'
515
+ : 'status-timeout';
516
+ badge.className = 'status-badge ' + cls;
517
+ badge.textContent = ep.status.toUpperCase();
518
+ }
519
+
520
+ const m = ep.metrics || {};
521
+ const compact = m.compactness != null ? m.compactness.toFixed(2)
522
+ : (ep.final_metrics && ep.final_metrics.compactness != null ? ep.final_metrics.compactness.toFixed(2) : '\u2014');
523
+ const score = ep.score != null ? ep.score.toFixed(1) : '\u2014';
524
+ const valid = m.is_valid != null ? (m.is_valid ? '\u2713' : '\u2717') : '\u2014';
525
+
526
+ const stepEl = document.getElementById('step-' + id);
527
+ const compEl = document.getElementById('compact-' + id);
528
+ const rewEl = document.getElementById('reward-' + id);
529
+ const valEl = document.getElementById('valid-' + id);
530
+
531
+ if (stepEl) stepEl.textContent = ep.step != null ? ep.step : 0;
532
+ if (compEl) {
533
+ compEl.textContent = compact;
534
+ const val = parseFloat(compact);
535
+ compEl.className = 'm-val' + (isNaN(val) ? '' : val > 0.5 ? ' good' : val < 0.2 ? ' bad' : '');
536
+ }
537
+ if (rewEl) {
538
+ rewEl.textContent = score;
539
+ const val = parseFloat(score);
540
+ rewEl.className = 'm-val' + (isNaN(val) ? '' : val > 5 ? ' good' : val < 0 ? ' bad' : '');
541
+ }
542
+ if (valEl) {
543
+ valEl.textContent = valid;
544
+ valEl.className = 'm-val' + (valid === '\u2713' ? ' good' : valid === '\u2717' ? ' bad' : '');
545
+ }
546
+
547
+ updateFoldHistory(id);
548
+ }
549
+
550
+ function updateFoldHistory(id) {
551
+ const ep = state.episodes[id];
552
+ const container = document.getElementById('folds-' + id);
553
+ if (!container || !ep) return;
554
+ const history = ep.fold_history || [];
555
+
556
+ while (container.firstChild) container.removeChild(container.firstChild);
557
+
558
+ if (!history.length) {
559
+ const noFolds = makeEl('span', { textContent: 'NO FOLDS YET' });
560
+ noFolds.style.color = 'var(--dim)';
561
+ container.appendChild(noFolds);
562
+ return;
563
+ }
564
+
565
+ history.forEach(function(f, i) {
566
+ const type = f.type || 'valley';
567
+ const cls = type === 'mountain' ? 'fold-type-mountain' : 'fold-type-valley';
568
+ const startCoords = (f.line && f.line.start) ? f.line.start.map(function(n) { return n.toFixed(2); }).join(',') : '\u2014';
569
+ const endCoords = (f.line && f.line.end) ? f.line.end.map(function(n) { return n.toFixed(2); }).join(',') : '\u2014';
570
+
571
+ const entry = makeEl('div', { className: 'fold-entry' });
572
+
573
+ const stepNum = makeEl('span', { className: 'step-num', textContent: '#' + (i + 1) });
574
+ const typeBadge = makeEl('span', { className: 'fold-type-badge ' + cls, textContent: type.toUpperCase() });
575
+ const coords = makeEl('span', { textContent: '[' + startCoords + ']\u2192[' + endCoords + ']' });
576
+
577
+ entry.appendChild(stepNum);
578
+ entry.appendChild(typeBadge);
579
+ entry.appendChild(coords);
580
+ container.appendChild(entry);
581
+ });
582
+ }
583
+
584
+ function enterFullscreen(id) {
585
+ // Navigate to the full React UI with this episode loaded
586
+ window.location.href = `/?ep=${encodeURIComponent(id)}`;
587
+ return;
588
+ if (state.fullscreenId === id) return;
589
+ if (state.fullscreenId) exitFullscreen();
590
+ state.fullscreenId = id;
591
+ const cell = document.getElementById('cell-' + id);
592
+ if (cell) {
593
+ cell.classList.add('fullscreen');
594
+ const r = renderers[id];
595
+ if (r) {
596
+ r.canvas.width = Math.min(window.innerWidth * 0.7, 800);
597
+ r.canvas.height = Math.min(window.innerHeight * 0.6, 600);
598
+ if (r.lastVerts && r.lastFaces) {
599
+ drawMesh(id, r.lastVerts, r.lastFaces, r.lastStrain);
600
+ }
601
+ }
602
+ updateFoldHistory(id);
603
+ }
604
+ }
605
+
606
+ function exitFullscreen(e) {
607
+ if (e) e.stopPropagation();
608
+ if (!state.fullscreenId) return;
609
+ const cell = document.getElementById('cell-' + state.fullscreenId);
610
+ if (cell) {
611
+ cell.classList.remove('fullscreen');
612
+ const r = renderers[state.fullscreenId];
613
+ if (r) {
614
+ r.canvas.width = 240;
615
+ r.canvas.height = 180;
616
+ if (r.lastVerts && r.lastFaces) {
617
+ drawMesh(state.fullscreenId, r.lastVerts, r.lastFaces, r.lastStrain);
618
+ } else {
619
+ drawFlatSheet(state.fullscreenId);
620
+ }
621
+ }
622
+ }
623
+ state.fullscreenId = null;
624
+ }
625
+
626
+ const LIGHT = normalize3([0.4, -0.45, 1.0]);
627
+ const PAPER_COLOR = [250, 250, 245];
628
+
629
+ function normalize3(v) {
630
+ const m = Math.hypot(v[0], v[1], v[2]);
631
+ return m < 1e-12 ? [0,0,0] : [v[0]/m, v[1]/m, v[2]/m];
632
+ }
633
+
634
+ function cross3(a, b) {
635
+ return [a[1]*b[2]-a[2]*b[1], a[2]*b[0]-a[0]*b[2], a[0]*b[1]-a[1]*b[0]];
636
+ }
637
+
638
+ function dot3(a, b) { return a[0]*b[0]+a[1]*b[1]+a[2]*b[2]; }
639
+
640
+ function sub3(a, b) { return [a[0]-b[0], a[1]-b[1], a[2]-b[2]]; }
641
+
642
+ function projectVert(v, cx, cy, scale) {
643
+ var x = v[0] - 0.5;
644
+ var y = v[1] - 0.5;
645
+ var z = v[2] || 0;
646
+
647
+ var pitch = 0.62, yaw = -0.52;
648
+ var cp = Math.cos(pitch), sp = Math.sin(pitch);
649
+ var y1 = y*cp - z*sp;
650
+ var z1 = y*sp + z*cp;
651
+ var cy2 = Math.cos(yaw), sy = Math.sin(yaw);
652
+ var x2 = x*cy2 + z1*sy;
653
+ var z2 = -x*sy + z1*cy2;
654
+
655
+ var camDist = 2.8;
656
+ var persp = camDist / (camDist - z2);
657
+ return { x: cx + x2 * persp * scale, y: cy - y1 * persp * scale, z: z2 };
658
+ }
659
+
660
+ function strainColor(s) {
661
+ var t = Math.min(Math.max(s || 0, 0), 0.2) / 0.2;
662
+ var r = Math.round(50 + t * 200);
663
+ var g = Math.round(250 - t * 200);
664
+ var b = Math.round(245 - t * 200);
665
+ return 'rgb(' + r + ',' + g + ',' + b + ')';
666
+ }
667
+
668
+ function renderStep(id, paperState) {
669
+ if (!paperState) return;
670
+ var verts = paperState.vertices_coords;
671
+ var faces = paperState.faces_vertices;
672
+ var strain = paperState.strain_per_vertex;
673
+ if (!verts || !faces) return;
674
+ drawMesh(id, verts, faces, strain);
675
+ }
676
+
677
+ function drawMesh(id, verts, faces, strain) {
678
+ var r = renderers[id];
679
+ if (!r) return;
680
+ r.lastVerts = verts;
681
+ r.lastFaces = faces;
682
+ r.lastStrain = strain;
683
+
684
+ var canvas = r.canvas, ctx = r.ctx;
685
+ var w = canvas.width, h = canvas.height;
686
+ var scale = Math.min(w, h) * 0.8;
687
+ var cx = w * 0.5, cy = h * 0.52;
688
+
689
+ ctx.clearRect(0, 0, w, h);
690
+ ctx.fillStyle = '#080810';
691
+ ctx.fillRect(0, 0, w, h);
692
+
693
+ var projected = verts.map(function(v) { return projectVert(v, cx, cy, scale); });
694
+
695
+ var tris = faces.map(function(face) {
696
+ var idxs = face.length > 3
697
+ ? [face[0], face[1], face[2], face[0], face[2], face[3] || face[2]]
698
+ : face;
699
+ var a = idxs[0], b = idxs[1], c = idxs[2];
700
+ var p0 = projected[a], p1 = projected[b], p2 = projected[c];
701
+ var avgZ = (p0.z + p1.z + p2.z) / 3;
702
+ var v0 = verts[a] || [0,0,0], v1 = verts[b] || [0,0,0], v2 = verts[c] || [0,0,0];
703
+ var norm = normalize3(cross3(sub3(v1,v0), sub3(v2,v0)));
704
+ var intensity = Math.abs(dot3(norm, LIGHT));
705
+ var avgStrain = strain ? (((strain[a]||0) + (strain[b]||0) + (strain[c]||0)) / 3) : 0;
706
+ return { face: [a,b,c], avgZ: avgZ, intensity: intensity, avgStrain: avgStrain };
707
+ });
708
+
709
+ tris.sort(function(a, b) { return a.avgZ - b.avgZ; });
710
+
711
+ for (var i = 0; i < tris.length; i++) {
712
+ var tri = tris[i];
713
+ var a = tri.face[0], b = tri.face[1], c = tri.face[2];
714
+ var p0 = projected[a], p1 = projected[b], p2 = projected[c];
715
+ if (!p0 || !p1 || !p2) continue;
716
+
717
+ var lit = Math.min(Math.max(0.3 + 0.7 * tri.intensity, 0), 1);
718
+ var fillColor;
719
+ if (tri.avgStrain > 0.005) {
720
+ fillColor = strainColor(tri.avgStrain);
721
+ } else {
722
+ var rv = Math.round(PAPER_COLOR[0] * lit);
723
+ var gv = Math.round(PAPER_COLOR[1] * lit);
724
+ var bv = Math.round(PAPER_COLOR[2] * lit);
725
+ fillColor = 'rgb(' + rv + ',' + gv + ',' + bv + ')';
726
+ }
727
+
728
+ ctx.beginPath();
729
+ ctx.moveTo(p0.x, p0.y);
730
+ ctx.lineTo(p1.x, p1.y);
731
+ ctx.lineTo(p2.x, p2.y);
732
+ ctx.closePath();
733
+ ctx.fillStyle = fillColor;
734
+ ctx.fill();
735
+ ctx.strokeStyle = 'rgba(42,42,58,0.3)';
736
+ ctx.lineWidth = 0.5;
737
+ ctx.stroke();
738
+ }
739
+ }
740
+
741
+ function drawFlatSheet(id) {
742
+ var flatVerts = [[0,0,0],[1,0,0],[1,1,0],[0,1,0]];
743
+ var flatFaces = [[0,1,2],[0,2,3]];
744
+ drawMesh(id, flatVerts, flatFaces, null);
745
+ }
746
+
747
+ function setTrainingBadge(label, cls) {
748
+ var b = document.getElementById('trainBadge');
749
+ b.textContent = label;
750
+ b.className = 'badge ' + cls;
751
+ }
752
+
753
+ function updateHeader() {
754
+ document.getElementById('batchNum').textContent = state.batchId != null ? state.batchId : '\u2014';
755
+ document.getElementById('epCount').textContent = Object.keys(state.episodes).length;
756
+ }
757
+
758
+ connectWS();
759
+ </script>
760
+ </body>
761
+ </html>