sissississi Claude Opus 4.6 committed on
Commit
c0cedb4
·
1 Parent(s): 490094b

Route rewards through OpenEnv API instead of local computation

Browse files

- Add Next.js rewrite proxy: /api/env/* → localhost:8000 (OpenEnv backend)
- Replace local shape_match_reward with openenv_reward that calls:
POST /api/env/reset (start episode for task)
POST /api/env/step (submit FOLD, get reward from environment)
- Add OptigamiEnvClient class for clean API interaction
- Fetch task list from environment instead of hardcoding
- Test cell verifies full reset→step→reward loop before training
- valid_fold_reward stays local (just JSON parsing, no simulation)

The training loop now properly uses the OpenEnv environment for
reward computation, satisfying the submission requirement.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. next.config.ts +8 -0
  2. training/train_grpo.ipynb +234 -335
next.config.ts CHANGED
@@ -20,6 +20,14 @@ const nextConfig: NextConfig = {
20
  ],
21
  },
22
  output: 'standalone',
 
 
 
 
 
 
 
 
23
  transpilePackages: ['motion'],
24
  webpack: (config, {dev}) => {
25
  // HMR is disabled in AI Studio via DISABLE_HMR env var.
 
20
  ],
21
  },
22
  output: 'standalone',
23
+ async rewrites() {
24
+ return [
25
+ {
26
+ source: '/api/env/:path*',
27
+ destination: 'http://localhost:8000/:path*',
28
+ },
29
+ ];
30
+ },
31
  transpilePackages: ['motion'],
32
  webpack: (config, {dev}) => {
33
  // HMR is disabled in AI Studio via DISABLE_HMR env var.
training/train_grpo.ipynb CHANGED
@@ -4,12 +4,12 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# Optigami — GRPO Training for Origami Crease Pattern Generation\n",
8
  "\n",
9
  "Train an LLM to generate FOLD-format crease patterns using **GRPO** with **Unsloth** + **TRL**.\n",
10
  "\n",
11
- "The model receives diverse origami descriptions and must learn to generate valid crease patterns\n",
12
- "that fold into the target shapes — scored by structural validity + shape similarity.\n",
13
  "\n",
14
  "**Environment**: [openenv-community/optigami_](https://huggingface.co/spaces/openenv-community/optigami_) (OpenEnv 0.2.1)\n",
15
  "\n",
@@ -45,159 +45,19 @@
45
  "elif importlib.util.find_spec(\"unsloth\") is None:\n",
46
  " !pip install -qqq unsloth trackio\n",
47
  "!pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n",
48
- "!pip install -qqq scipy datasets"
49
  ]
50
  },
51
  {
52
  "cell_type": "markdown",
53
  "metadata": {},
54
  "source": [
55
- "## 2. Origami Engine (Inlined)"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "metadata": {},
62
- "outputs": [],
63
- "source": [
64
- "import json, re, random\n",
65
- "from collections import defaultdict\n",
66
- "from dataclasses import dataclass\n",
67
- "from typing import Any\n",
68
- "\n",
69
- "import numpy as np\n",
70
- "from scipy.spatial.distance import cdist\n",
71
- "from scipy.spatial.transform import Rotation\n",
72
- "\n",
73
- "\n",
74
- "def validate_fold(fold_data):\n",
75
- " for key in (\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"):\n",
76
- " if key not in fold_data: return False, f\"Missing: {key}\"\n",
77
- " verts, edges, asgn = fold_data[\"vertices_coords\"], fold_data[\"edges_vertices\"], fold_data[\"edges_assignment\"]\n",
78
- " if len(verts) < 3: return False, \"<3 verts\"\n",
79
- " if len(edges) < 3: return False, \"<3 edges\"\n",
80
- " if len(edges) != len(asgn): return False, \"len mismatch\"\n",
81
- " if \"edges_foldAngle\" in fold_data and len(fold_data[\"edges_foldAngle\"]) != len(edges): return False, \"angle len\"\n",
82
- " nv = len(verts)\n",
83
- " for i, v in enumerate(verts):\n",
84
- " if not isinstance(v, (list, tuple)) or len(v) < 2: return False, f\"vert {i}\"\n",
85
- " for i, e in enumerate(edges):\n",
86
- " if not isinstance(e, (list, tuple)) or len(e) != 2: return False, f\"edge {i}\"\n",
87
- " if e[0]<0 or e[0]>=nv or e[1]<0 or e[1]>=nv: return False, f\"edge {i} range\"\n",
88
- " if e[0]==e[1]: return False, f\"edge {i} degen\"\n",
89
- " for i, a in enumerate(asgn):\n",
90
- " if a not in {\"M\",\"V\",\"B\",\"F\",\"U\",\"C\"}: return False, f\"asgn {i}\"\n",
91
- " if not any(a in (\"M\",\"V\") for a in asgn): return False, \"no folds\"\n",
92
- " if not any(a==\"B\" for a in asgn): return False, \"no boundary\"\n",
93
- " return True, \"\"\n",
94
- "\n",
95
- "def parse_fold(fd):\n",
96
- " verts = fd[\"vertices_coords\"]\n",
97
- " vertices = np.zeros((len(verts),3), dtype=np.float64)\n",
98
- " for i,v in enumerate(verts): vertices[i,0]=v[0]; vertices[i,1]=v[1]; (len(v)>2 and setattr(vertices, '__setitem__', None) is None) or (vertices.__setitem__((i,2), v[2]) if len(v)>2 else None)\n",
99
- " # Fix the above — simple approach\n",
100
- " vertices = np.zeros((len(verts),3), dtype=np.float64)\n",
101
- " for i,v in enumerate(verts):\n",
102
- " vertices[i,0]=v[0]; vertices[i,1]=v[1]\n",
103
- " if len(v)>2: vertices[i,2]=v[2]\n",
104
- " edges = np.array(fd[\"edges_vertices\"], dtype=np.int32)\n",
105
- " asgn = list(fd[\"edges_assignment\"])\n",
106
- " if \"edges_foldAngle\" in fd: fa = np.array(fd[\"edges_foldAngle\"], dtype=np.float64)\n",
107
- " else:\n",
108
- " fa = np.zeros(len(edges), dtype=np.float64)\n",
109
- " for i,a in enumerate(asgn):\n",
110
- " if a==\"V\": fa[i]=180.0\n",
111
- " elif a==\"M\": fa[i]=-180.0\n",
112
- " fa_rad = np.radians(fa)\n",
113
- " if \"faces_vertices\" in fd:\n",
114
- " tris = []\n",
115
- " for face in fd[\"faces_vertices\"]:\n",
116
- " for i in range(1,len(face)-1): tris.append([face[0],face[i],face[i+1]])\n",
117
- " faces = np.array(tris, dtype=np.int32) if tris else np.zeros((0,3), dtype=np.int32)\n",
118
- " else:\n",
119
- " adj=defaultdict(set)\n",
120
- " for v1,v2 in edges: adj[v1].add(v2); adj[v2].add(v1)\n",
121
- " ts=set()\n",
122
- " for v1,v2 in edges:\n",
123
- " for v3 in adj[v1]&adj[v2]: ts.add(tuple(sorted([v1,v2,v3])))\n",
124
- " faces = np.array(list(ts), dtype=np.int32) if ts else np.zeros((0,3), dtype=np.int32)\n",
125
- " return {\"vertices\":vertices,\"edges\":edges,\"assignments\":asgn,\"fold_angles\":fa_rad,\"faces\":faces}\n",
126
- "\n",
127
- "@dataclass\n",
128
- "class SimResult:\n",
129
- " positions: np.ndarray; converged: bool; steps_taken: int; max_strain: float; total_energy: float\n",
130
- "\n",
131
- "def simulate(fold_data, crease_percent=1.0):\n",
132
- " p = parse_fold(fold_data)\n",
133
- " flat=p[\"vertices\"].copy(); edges=p[\"edges\"]; asgn=p[\"assignments\"]; fa=p[\"fold_angles\"]; faces=p[\"faces\"]; pos=flat.copy()\n",
134
- " if len(faces)==0: return SimResult(pos,True,0,0.,0.)\n",
135
- " fadj={}\n",
136
- " for fi,face in enumerate(faces):\n",
137
- " for j in range(len(face)):\n",
138
- " v1,v2=int(face[j]),int(face[(j+1)%len(face)]); k=(min(v1,v2),max(v1,v2)); fadj.setdefault(k,[]).append(fi)\n",
139
- " cm={}\n",
140
- " for i,(v1,v2) in enumerate(edges):\n",
141
- " k=(min(int(v1),int(v2)),max(int(v1),int(v2)))\n",
142
- " if asgn[i] in (\"M\",\"V\"): cm[k]=fa[i]*crease_percent\n",
143
- " nf=len(faces); fR=[None]*nf; ft=[None]*nf; fR[0]=np.eye(3); ft[0]=np.zeros(3)\n",
144
- " vis=[False]*nf; vis[0]=True; placed=set(int(vi) for vi in faces[0])\n",
145
- " q=[0]\n",
146
- " while q:\n",
147
- " fi=q.pop(0); face=faces[fi]\n",
148
- " for j in range(len(face)):\n",
149
- " v1,v2=int(face[j]),int(face[(j+1)%len(face)]); ek=(min(v1,v2),max(v1,v2))\n",
150
- " for fj in fadj.get(ek,[]):\n",
151
- " if vis[fj]: continue\n",
152
- " vis[fj]=True; q.append(fj); ang=cm.get(ek,0.0)\n",
153
- " if abs(ang)>1e-10:\n",
154
- " p1=pos[v1].copy(); ax=pos[v2]-p1; al=np.linalg.norm(ax)\n",
155
- " fr=Rotation.from_rotvec(ang*ax/al).as_matrix() if al>1e-12 else np.eye(3)\n",
156
- " fR[fj]=fr@fR[fi]; ft[fj]=fr@(ft[fi]-p1)+p1\n",
157
- " else: fR[fj]=fR[fi].copy(); ft[fj]=ft[fi].copy()\n",
158
- " for vi in faces[fj]:\n",
159
- " vi2=int(vi)\n",
160
- " if vi2 not in placed: pos[vi2]=fR[fj]@flat[vi2]+ft[fj]; placed.add(vi2)\n",
161
- " ms=0.0\n",
162
- " for v1,v2 in edges:\n",
163
- " r=np.linalg.norm(flat[v2]-flat[v1]); c=np.linalg.norm(pos[v2]-pos[v1])\n",
164
- " if r>1e-12: ms=max(ms,abs(c-r)/r)\n",
165
- " return SimResult(pos,True,1,ms,0.)\n",
166
- "\n",
167
- "def _rotations():\n",
168
- " rs=[np.eye(3)]\n",
169
- " for k in range(1,4):\n",
170
- " a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
171
- " rs.append(np.array([[c,-s,0],[s,c,0],[0,0,1]]))\n",
172
- " for k in range(1,4):\n",
173
- " a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
174
- " rs.append(np.array([[1,0,0],[0,c,-s],[0,s,c]]))\n",
175
- " for k in range(1,4):\n",
176
- " a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
177
- " rs.append(np.array([[c,0,s],[0,1,0],[-s,0,c]]))\n",
178
- " rs+=[np.diag([-1.,1.,1.]),np.diag([1.,-1.,1.]),np.diag([1.,1.,-1.])]\n",
179
- " return rs\n",
180
- "\n",
181
- "def compute_shape_match(pred, tgt):\n",
182
- " if len(pred)==0 or len(tgt)==0: return 0.0\n",
183
- " pc=pred-pred.mean(0); tc=tgt-tgt.mean(0); best=0.0\n",
184
- " for r in _rotations():\n",
185
- " rot=pc@r.T; d=cdist(rot,tc); ch=(d.min(1).mean()+d.min(0).mean())/2\n",
186
- " ap=np.vstack([rot,tc]); dg=np.linalg.norm(ap.max(0)-ap.min(0))\n",
187
- " sc=max(0.,1.-ch/dg) if dg>1e-12 else (1. if ch<1e-12 else 0.)\n",
188
- " best=max(best,sc)\n",
189
- " return best\n",
190
- "\n",
191
- "print(\"Engine loaded.\")"
192
- ]
193
- },
194
- {
195
- "cell_type": "markdown",
196
- "metadata": {},
197
- "source": [
198
- "## 3. Task Definitions\n",
199
  "\n",
200
- "Diverse origami tasks — the model must generalize across different shapes, not memorize one."
 
 
 
201
  ]
202
  },
203
  {
@@ -206,104 +66,75 @@
206
  "metadata": {},
207
  "outputs": [],
208
  "source": [
209
- "TASKS = {\n",
210
- " \"triangle\": {\n",
211
- " \"name\": \"triangle\",\n",
212
- " \"description\": \"Fold the paper in half diagonally to make a triangle\",\n",
213
- " \"difficulty\": 1, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
214
- " \"target_fold\": {\n",
215
- " \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1]],\n",
216
- " \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n",
217
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
218
- " \"edges_foldAngle\": [0,0,0,0,180],\n",
219
- " \"faces_vertices\": [[0,1,2],[0,2,3]],\n",
220
- " },\n",
221
- " },\n",
222
- " \"half_fold\": {\n",
223
- " \"name\": \"half_fold\",\n",
224
- " \"description\": \"Fold the paper in half horizontally along the middle\",\n",
225
- " \"difficulty\": 1, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
226
- " \"target_fold\": {\n",
227
- " \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1],[0,0.5],[1,0.5]],\n",
228
- " \"edges_vertices\": [[0,1],[1,5],[5,2],[2,3],[3,4],[4,0],[4,5]],\n",
229
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
230
- " \"edges_foldAngle\": [0,0,0,0,0,0,180],\n",
231
- " \"faces_vertices\": [[0,1,5,4],[4,5,2,3]],\n",
232
- " },\n",
233
- " },\n",
234
- " \"quarter_fold\": {\n",
235
- " \"name\": \"quarter_fold\",\n",
236
- " \"description\": \"Fold the paper into quarters with two perpendicular creases through the center\",\n",
237
- " \"difficulty\": 2, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
238
- " \"target_fold\": {\n",
239
- " \"vertices_coords\": [[0,0],[0.5,0],[1,0],[0,0.5],[0.5,0.5],[1,0.5],[0,1],[0.5,1],[1,1]],\n",
240
- " \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[1,4],[4,7],[3,4],[4,5]],\n",
241
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"V\",\"V\",\"V\"],\n",
242
- " \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,180,180,180],\n",
243
- " \"faces_vertices\": [[0,1,4,3],[1,2,5,4],[3,4,7,6],[4,5,8,7]],\n",
244
- " },\n",
245
- " },\n",
246
- " \"accordion\": {\n",
247
- " \"name\": \"accordion\",\n",
248
- " \"description\": \"Make a zig-zag accordion fold with alternating mountain and valley creases like a paper fan\",\n",
249
- " \"difficulty\": 2, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
250
- " \"target_fold\": {\n",
251
- " \"vertices_coords\": [[-1,1],[-0.5,1],[0,1],[0.5,1],[1,1],[-1,-1],[-0.5,-1],[0,-1],[0.5,-1],[1,-1]],\n",
252
- " \"edges_vertices\": [[0,1],[1,2],[2,3],[3,4],[5,6],[6,7],[7,8],[8,9],[0,5],[4,9],[1,6],[2,7],[3,8]],\n",
253
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"M\",\"V\"],\n",
254
- " \"edges_foldAngle\": [0,0,0,0,0,0,0,0,0,0,180,-180,180],\n",
255
- " \"faces_vertices\": [[0,1,6,5],[1,2,7,6],[2,3,8,7],[3,4,9,8]],\n",
256
- " },\n",
257
- " },\n",
258
- " \"waterbomb\": {\n",
259
- " \"name\": \"waterbomb\",\n",
260
- " \"description\": \"Create a waterbomb base with valley folds on diagonals and mountain folds on midlines of a square\",\n",
261
- " \"difficulty\": 3, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
262
- " \"target_fold\": {\n",
263
- " \"vertices_coords\": [[-1,1],[0,1],[1,1],[-1,0],[0,0],[1,0],[-1,-1],[0,-1],[1,-1]],\n",
264
- " \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[0,4],[2,4],[6,4],[8,4],[1,4],[3,4],[5,4],[7,4]],\n",
265
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"V\",\"V\",\"V\",\"M\",\"M\",\"M\",\"M\"],\n",
266
- " \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,180,180,180,-180,-180,-180,-180],\n",
267
- " \"faces_vertices\": [[0,1,4],[1,2,4],[2,5,4],[5,8,4],[8,7,4],[7,6,4],[6,3,4],[3,0,4]],\n",
268
- " },\n",
269
- " },\n",
270
- " \"miura_ori\": {\n",
271
- " \"name\": \"miura_ori\",\n",
272
- " \"description\": \"Create a Miura-ori tessellation with offset zigzag vertices that folds flat in one motion\",\n",
273
- " \"difficulty\": 3, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
274
- " \"target_fold\": {\n",
275
- " \"vertices_coords\": [[-1,1],[0,1.2],[1,1],[-1,0],[0,0.2],[1,0],[-1,-1],[0,-0.8],[1,-1]],\n",
276
- " \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[1,4],[4,7],[3,4],[4,5]],\n",
277
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"M\",\"M\",\"V\",\"M\"],\n",
278
- " \"edges_foldAngle\": [0,0,0,0,0,0,0,0,-180,-180,180,-180],\n",
279
- " \"faces_vertices\": [[0,1,4,3],[1,2,5,4],[3,4,7,6],[4,5,8,7]],\n",
280
- " },\n",
281
- " },\n",
282
- " \"letter_fold\": {\n",
283
- " \"name\": \"letter_fold\",\n",
284
- " \"description\": \"Tri-fold the paper into thirds like a letter envelope with two parallel horizontal creases\",\n",
285
- " \"difficulty\": 2, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
286
- " \"target_fold\": {\n",
287
- " \"vertices_coords\": [[0,0],[1,0],[0,0.333],[1,0.333],[0,0.667],[1,0.667],[0,1],[1,1]],\n",
288
- " \"edges_vertices\": [[0,1],[1,3],[3,5],[5,7],[7,6],[6,4],[4,2],[2,0],[2,3],[4,5]],\n",
289
- " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"M\"],\n",
290
- " \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,-180],\n",
291
- " \"faces_vertices\": [[0,1,3,2],[2,3,5,4],[4,5,7,6]],\n",
292
- " },\n",
293
- " },\n",
294
- "}\n",
295
  "\n",
296
- "def get_task(name): return TASKS[name]\n",
297
- "print(f\"Tasks ({len(TASKS)}): {list(TASKS.keys())}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  ]
299
  },
300
  {
301
  "cell_type": "markdown",
302
  "metadata": {},
303
  "source": [
304
- "## 4. Prompt Template & Multi-Task Dataset\n",
305
  "\n",
306
- "Each row in the dataset is a **different task** — the model must generalize across shapes."
307
  ]
308
  },
309
  {
@@ -312,62 +143,37 @@
312
  "metadata": {},
313
  "outputs": [],
314
  "source": [
315
- "from datasets import Dataset\n",
316
- "\n",
317
- "PROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\n",
318
- "that, when folded, produces the target shape described below.\n",
319
- "\n",
320
- "Target: {description}\n",
321
- "Paper size: {width} x {height}\n",
322
- "\n",
323
- "Output a JSON object with these exact fields:\n",
324
- "- vertices_coords: [[x, y], ...] — 2D positions on the flat paper\n",
325
- "- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges\n",
326
- "- edges_assignment: [\"B\"|\"M\"|\"V\", ...] — B=boundary, M=mountain fold, V=valley fold\n",
327
- "- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative, V: positive, B: 0)\n",
328
- "\n",
329
- "Rules:\n",
330
- "- Boundary edges (B) must outline the paper rectangle\n",
331
- "- At least one fold crease (M or V) must exist\n",
332
- "- Mountain fold angles are negative (-180 to 0)\n",
333
- "- Valley fold angles are positive (0 to 180)\n",
334
- "- All vertex indices in edges must be valid (0 to N-1)\n",
335
- "\n",
336
- "Output ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n",
337
- "\n",
338
- "\n",
339
- "def build_prompt(task):\n",
340
- " return PROMPT_TEMPLATE.format(\n",
341
- " description=task[\"description\"],\n",
342
- " width=task[\"paper\"][\"width\"],\n",
343
- " height=task[\"paper\"][\"height\"],\n",
344
- " )\n",
345
- "\n",
346
- "\n",
347
- "# Build multi-task dataset — rotate through all tasks\n",
348
- "task_names = list(TASKS.keys())\n",
349
- "rows = []\n",
350
- "for i in range(1000):\n",
351
- " tn = task_names[i % len(task_names)]\n",
352
- " t = TASKS[tn]\n",
353
- " rows.append({\n",
354
- " \"prompt\": [{\"role\": \"user\", \"content\": build_prompt(t)}],\n",
355
- " \"task_name\": tn,\n",
356
- " \"answer\": 0,\n",
357
- " })\n",
358
- "\n",
359
- "grpo_dataset = Dataset.from_list(rows)\n",
360
- "print(f\"Dataset: {len(grpo_dataset)} rows across {len(task_names)} tasks\")\n",
361
- "print(f\"Task distribution: {dict((tn, sum(1 for r in rows if r['task_name']==tn)) for tn in task_names)}\")"
362
  ]
363
  },
364
  {
365
  "cell_type": "markdown",
366
  "metadata": {},
367
  "source": [
368
- "## 5. Reward Functions\n",
369
  "\n",
370
- "The reward functions read `task_name` from the dataset row to know which target to compare against."
 
371
  ]
372
  },
373
  {
@@ -379,6 +185,7 @@
379
  "PRINTER = 0\n",
380
  "\n",
381
  "def extract_fold_json(response):\n",
 
382
  " m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
383
  " if m:\n",
384
  " try: return json.loads(m.group(1))\n",
@@ -395,52 +202,95 @@
395
  "\n",
396
  "\n",
397
  "def valid_fold_reward(completions, **kwargs):\n",
398
- " \"\"\"Reward 1: +1.0 valid FOLD, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
 
399
  " scores = []\n",
400
  " for c in completions:\n",
401
  " fold = extract_fold_json(c[0][\"content\"])\n",
402
- " if fold is None: scores.append(-2.0); continue\n",
403
- " ok, _ = validate_fold(fold)\n",
404
- " scores.append(1.0 if ok else -0.5)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  " return scores\n",
406
  "\n",
407
  "\n",
408
- "def shape_match_reward(completions, task_name, **kwargs):\n",
409
- " \"\"\"Reward 2: similarity*20 for matching target shape. Uses task_name from dataset.\"\"\"\n",
 
 
 
 
410
  " global PRINTER\n",
411
- " # task_name comes as a list (one per completion in the batch)\n",
412
- " if isinstance(task_name, list):\n",
413
- " tn = task_name[0] # All completions in a group share the same prompt/task\n",
414
- " else:\n",
415
- " tn = task_name\n",
416
- " task = get_task(tn)\n",
417
- " try: tgt = simulate(task[\"target_fold\"]).positions\n",
418
- " except: return [0.0]*len(completions)\n",
419
  "\n",
420
  " scores = []\n",
421
  " for c in completions:\n",
422
  " resp = c[0][\"content\"]\n",
 
 
423
  " if PRINTER % 10 == 0:\n",
424
  " print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
425
  " PRINTER += 1\n",
 
 
426
  " fold = extract_fold_json(resp)\n",
427
- " if fold is None: scores.append(-2.0); continue\n",
428
- " ok, _ = validate_fold(fold)\n",
429
- " if not ok: scores.append(-1.0); continue\n",
 
430
  " try:\n",
431
- " sim = compute_shape_match(simulate(fold).positions, tgt)\n",
432
- " scores.append(sim * 20.0)\n",
433
- " except: scores.append(-1.0)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  " return scores\n",
435
  "\n",
436
- "print(\"Reward functions ready.\")"
 
437
  ]
438
  },
439
  {
440
  "cell_type": "markdown",
441
  "metadata": {},
442
  "source": [
443
- "## 6. Quick Engine Test"
 
 
444
  ]
445
  },
446
  {
@@ -449,18 +299,63 @@
449
  "metadata": {},
450
  "outputs": [],
451
  "source": [
452
- "for tn in TASKS:\n",
453
- " t = TASKS[tn]\n",
454
- " r = simulate(t[\"target_fold\"])\n",
455
- " ss = compute_shape_match(r.positions, r.positions)\n",
456
- " print(f\"{tn:15s} | verts={len(t['target_fold']['vertices_coords']):2d} | self-sim={ss:.3f} | strain={r.max_strain:.4f}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  ]
458
  },
459
  {
460
  "cell_type": "markdown",
461
  "metadata": {},
462
  "source": [
463
- "## 7. Configuration"
464
  ]
465
  },
466
  {
@@ -494,7 +389,7 @@
494
  "cell_type": "markdown",
495
  "metadata": {},
496
  "source": [
497
- "## 8. Load Model + LoRA"
498
  ]
499
  },
500
  {
@@ -523,13 +418,15 @@
523
  "cell_type": "markdown",
524
  "metadata": {},
525
  "source": [
526
- "## 9. GRPO Training\n",
 
 
 
 
 
 
527
  "\n",
528
- "Each step:\n",
529
- "1. Pick a task prompt (rotates through triangle, half_fold, quarter_fold, accordion, waterbomb, miura_ori, letter_fold)\n",
530
- "2. Generate N completions\n",
531
- "3. Score each: valid_fold (+1/-0.5/-2) + shape_match (0-20)\n",
532
- "4. GRPO uses within-group variance to update the policy"
533
  ]
534
  },
535
  {
@@ -562,12 +459,12 @@
562
  "trainer = GRPOTrainer(\n",
563
  " model=model,\n",
564
  " processing_class=tokenizer,\n",
565
- " reward_funcs=[valid_fold_reward, shape_match_reward],\n",
566
  " args=training_args,\n",
567
  " train_dataset=grpo_dataset,\n",
568
  ")\n",
569
  "\n",
570
- "print(\"Starting GRPO training...\")\n",
571
  "trainer.train()\n",
572
  "print(\"Done!\")"
573
  ]
@@ -576,7 +473,7 @@
576
  "cell_type": "markdown",
577
  "metadata": {},
578
  "source": [
579
- "## 10. Test Across Tasks"
580
  ]
581
  },
582
  {
@@ -587,8 +484,8 @@
587
  "source": [
588
  "FastLanguageModel.for_inference(model)\n",
589
  "\n",
590
- "for tn in [\"triangle\", \"half_fold\", \"accordion\", \"waterbomb\"]:\n",
591
- " task = get_task(tn)\n",
592
  " prompt = build_prompt(task)\n",
593
  " inputs = tokenizer.apply_chat_template(\n",
594
  " [{\"role\": \"user\", \"content\": prompt}],\n",
@@ -599,12 +496,14 @@
599
  "\n",
600
  " fold = extract_fold_json(response)\n",
601
  " if fold:\n",
602
- " ok, err = validate_fold(fold)\n",
603
- " if ok:\n",
604
- " sim = compute_shape_match(simulate(fold).positions, simulate(task[\"target_fold\"]).positions)\n",
605
- " print(f\"{tn:15s} | valid | similarity={sim:.3f} | reward={sim*20:.1f}\")\n",
606
- " else:\n",
607
- " print(f\"{tn:15s} | invalid: {err}\")\n",
 
 
608
  " else:\n",
609
  " print(f\"{tn:15s} | no JSON extracted\")"
610
  ]
@@ -613,7 +512,7 @@
613
  "cell_type": "markdown",
614
  "metadata": {},
615
  "source": [
616
- "## 11. Save Model"
617
  ]
618
  },
619
  {
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# Optigami — GRPO Training with OpenEnv\n",
8
  "\n",
9
  "Train an LLM to generate FOLD-format crease patterns using **GRPO** with **Unsloth** + **TRL**.\n",
10
  "\n",
11
+ "**Rewards come from the OpenEnv environment** deployed on HF Spaces — the training loop\n",
12
+ "calls `POST /reset` and `POST /step` on the live environment to get simulation results.\n",
13
  "\n",
14
  "**Environment**: [openenv-community/optigami_](https://huggingface.co/spaces/openenv-community/optigami_) (OpenEnv 0.2.1)\n",
15
  "\n",
 
45
  "elif importlib.util.find_spec(\"unsloth\") is None:\n",
46
  " !pip install -qqq unsloth trackio\n",
47
  "!pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n",
48
+ "!pip install -qqq scipy datasets requests"
49
  ]
50
  },
51
  {
52
  "cell_type": "markdown",
53
  "metadata": {},
54
  "source": [
55
+ "## 2. OpenEnv Client\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  "\n",
57
+ "Connect to the Optigami OpenEnv environment deployed on HF Spaces.\n",
58
+ "The environment handles fold simulation and reward computation via its API:\n",
59
+ "- `POST /api/env/reset` — start a new episode for a task\n",
60
+ "- `POST /api/env/step` — submit a FOLD crease pattern, get back reward + shape similarity"
61
  ]
62
  },
63
  {
 
66
  "metadata": {},
67
  "outputs": [],
68
  "source": [
69
+ "import requests\n",
70
+ "import json\n",
71
+ "import re\n",
72
+ "\n",
73
+ "# OpenEnv environment URL deployed on HF Spaces\n",
74
+ "OPENENV_URL = \"https://openenv-community-optigami-c92c300.hf.space/api/env\"\n",
75
+ "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  "\n",
77
+ "class OptigamiEnvClient:\n",
78
+ " \"\"\"Client for the Optigami OpenEnv environment on HF Spaces.\"\"\"\n",
79
+ "\n",
80
+ " def __init__(self, base_url=OPENENV_URL):\n",
81
+ " self.base_url = base_url\n",
82
+ " self.session = requests.Session()\n",
83
+ " self.session.headers.update({\"Content-Type\": \"application/json\"})\n",
84
+ "\n",
85
+ " def health(self):\n",
86
+ " \"\"\"Check if the environment is reachable.\"\"\"\n",
87
+ " try:\n",
88
+ " r = self.session.get(f\"{self.base_url}/health\", timeout=10)\n",
89
+ " return r.status_code == 200\n",
90
+ " except:\n",
91
+ " return False\n",
92
+ "\n",
93
+ " def reset(self, task_name=\"triangle\"):\n",
94
+ " \"\"\"Reset environment for a task. Returns observation dict.\"\"\"\n",
95
+ " r = self.session.post(\n",
96
+ " f\"{self.base_url}/reset\",\n",
97
+ " json={\"task_name\": task_name},\n",
98
+ " timeout=30,\n",
99
+ " )\n",
100
+ " r.raise_for_status()\n",
101
+ " return r.json()\n",
102
+ "\n",
103
+ " def step(self, fold_data):\n",
104
+ " \"\"\"Submit a FOLD crease pattern. Returns observation with reward.\"\"\"\n",
105
+ " r = self.session.post(\n",
106
+ " f\"{self.base_url}/step\",\n",
107
+ " json={\"action\": {\"fold_data\": fold_data}},\n",
108
+ " timeout=30,\n",
109
+ " )\n",
110
+ " r.raise_for_status()\n",
111
+ " return r.json()\n",
112
+ "\n",
113
+ " def get_tasks(self):\n",
114
+ " \"\"\"Fetch available tasks from the environment.\"\"\"\n",
115
+ " r = self.session.get(f\"{self.base_url.replace('/api/env', '')}/tasks\", timeout=10)\n",
116
+ " r.raise_for_status()\n",
117
+ " return r.json()\n",
118
+ "\n",
119
+ "\n",
120
+ "# Initialize client and verify connection\n",
121
+ "env = OptigamiEnvClient()\n",
122
+ "if env.health():\n",
123
+ " print(f\"Connected to OpenEnv at {OPENENV_URL}\")\n",
124
+ " tasks = env.get_tasks()\n",
125
+ " print(f\"Available tasks: {list(tasks.keys())}\")\n",
126
+ "else:\n",
127
+ " print(f\"WARNING: Cannot reach OpenEnv at {OPENENV_URL}\")\n",
128
+ " print(\"Make sure the HF Space is running!\")"
129
  ]
130
  },
131
  {
132
  "cell_type": "markdown",
133
  "metadata": {},
134
  "source": [
135
+ "## 3. Test the OpenEnv Environment\n",
136
  "\n",
137
+ "Verify the full loop: reset → step → get reward."
138
  ]
139
  },
140
  {
 
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "# Reset for triangle task\n",
147
+ "obs = env.reset(task_name=\"triangle\")\n",
148
+ "print(\"Reset observation:\")\n",
149
+ "print(f\" Task: {obs['observation']['task']}\")\n",
150
+ "print(f\" Done: {obs['observation']['done']}\")\n",
151
+ "print(f\" Target positions: {len(obs['observation']['target_positions'])} vertices\")\n",
152
+ "\n",
153
+ "# Submit the known-correct triangle fold\n",
154
+ "correct_fold = {\n",
155
+ " \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1]],\n",
156
+ " \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n",
157
+ " \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
158
+ " \"edges_foldAngle\": [0,0,0,0,180],\n",
159
+ " \"faces_vertices\": [[0,1,2],[0,2,3]],\n",
160
+ "}\n",
161
+ "result = env.step(correct_fold)\n",
162
+ "print(f\"\\nStep result:\")\n",
163
+ "print(f\" Reward: {result['reward']}\")\n",
164
+ "print(f\" Shape similarity: {result['observation']['shape_similarity']}\")\n",
165
+ "print(f\" Done: {result['observation']['done']}\")\n",
166
+ "print(f\" Is stable: {result['observation']['is_stable']}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  ]
168
  },
169
  {
170
  "cell_type": "markdown",
171
  "metadata": {},
172
  "source": [
173
+ "## 4. Reward Functions (via OpenEnv API)\n",
174
  "\n",
175
+ "- **`valid_fold_reward`**: Local JSON validation (fast, no API call needed)\n",
176
+ "- **`openenv_reward`**: Calls the OpenEnv environment to simulate the fold and get the real reward"
177
  ]
178
  },
179
  {
 
185
  "PRINTER = 0\n",
186
  "\n",
187
  "def extract_fold_json(response):\n",
188
+ " \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n",
189
  " m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
190
  " if m:\n",
191
  " try: return json.loads(m.group(1))\n",
 
202
  "\n",
203
  "\n",
204
  "def valid_fold_reward(completions, **kwargs):\n",
205
+ " \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
206
+ " REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n",
207
  " scores = []\n",
208
  " for c in completions:\n",
209
  " fold = extract_fold_json(c[0][\"content\"])\n",
210
+ " if fold is None:\n",
211
+ " scores.append(-2.0)\n",
212
+ " continue\n",
213
+ " # Basic structural checks\n",
214
+ " if not REQUIRED.issubset(fold.keys()):\n",
215
+ " scores.append(-0.5); continue\n",
216
+ " verts = fold[\"vertices_coords\"]\n",
217
+ " edges = fold[\"edges_vertices\"]\n",
218
+ " asgn = fold[\"edges_assignment\"]\n",
219
+ " if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n",
220
+ " scores.append(-0.5); continue\n",
221
+ " if not any(a in (\"M\",\"V\") for a in asgn):\n",
222
+ " scores.append(-0.5); continue\n",
223
+ " if not any(a == \"B\" for a in asgn):\n",
224
+ " scores.append(-0.5); continue\n",
225
+ " scores.append(1.0)\n",
226
  " return scores\n",
227
  "\n",
228
  "\n",
229
+ "def openenv_reward(completions, task_name, **kwargs):\n",
230
+ " \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n",
231
+ "\n",
232
+ " Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n",
233
+ " The environment runs the fold simulation and computes shape similarity.\n",
234
+ " \"\"\"\n",
235
  " global PRINTER\n",
236
+ " # task_name comes as a list from the dataset\n",
237
+ " tn = task_name[0] if isinstance(task_name, list) else task_name\n",
 
 
 
 
 
 
238
  "\n",
239
  " scores = []\n",
240
  " for c in completions:\n",
241
  " resp = c[0][\"content\"]\n",
242
+ "\n",
243
+ " # Periodic logging\n",
244
  " if PRINTER % 10 == 0:\n",
245
  " print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
246
  " PRINTER += 1\n",
247
+ "\n",
248
+ " # Parse the FOLD JSON from the LLM response\n",
249
  " fold = extract_fold_json(resp)\n",
250
+ " if fold is None:\n",
251
+ " scores.append(-2.0)\n",
252
+ " continue\n",
253
+ "\n",
254
  " try:\n",
255
+ " # Reset environment for this task\n",
256
+ " env.reset(task_name=tn)\n",
257
+ "\n",
258
+ " # Submit the fold to OpenEnv — environment simulates and scores it\n",
259
+ " result = env.step(fold)\n",
260
+ "\n",
261
+ " # Get reward from the environment\n",
262
+ " reward = result.get(\"reward\", None)\n",
263
+ " if reward is not None:\n",
264
+ " scores.append(float(reward))\n",
265
+ " else:\n",
266
+ " # Fallback: extract from observation\n",
267
+ " obs = result.get(\"observation\", {})\n",
268
+ " if obs.get(\"error\"):\n",
269
+ " scores.append(-2.0)\n",
270
+ " else:\n",
271
+ " sim = obs.get(\"shape_similarity\", 0.0)\n",
272
+ " scores.append(float(sim) * 20.0)\n",
273
+ "\n",
274
+ " except requests.exceptions.RequestException as e:\n",
275
+ " print(f\"OpenEnv API error: {e}\")\n",
276
+ " scores.append(-1.0)\n",
277
+ " except Exception as e:\n",
278
+ " print(f\"Reward error: {e}\")\n",
279
+ " scores.append(-1.0)\n",
280
+ "\n",
281
  " return scores\n",
282
  "\n",
283
+ "\n",
284
+ "print(\"Reward functions ready (valid_fold=local, openenv_reward=API).\")"
285
  ]
286
  },
287
  {
288
  "cell_type": "markdown",
289
  "metadata": {},
290
  "source": [
291
+ "## 5. Prompt Template & Multi-Task Dataset\n",
292
+ "\n",
293
+ "Rows cycle through the tasks fetched from the environment, so each task appears many times — the model must generalize across shapes."
294
  ]
295
  },
296
  {
 
299
  "metadata": {},
300
  "outputs": [],
301
  "source": [
302
+ "from datasets import Dataset\n",
303
+ "\n",
304
+ "PROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\n",
305
+ "that, when folded, produces the target shape described below.\n",
306
+ "\n",
307
+ "Target: {description}\n",
308
+ "Paper size: {width} x {height}\n",
309
+ "\n",
310
+ "Output a JSON object with these exact fields:\n",
311
+ "- vertices_coords: [[x, y], ...] — 2D positions on the flat paper\n",
312
+ "- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges\n",
313
+ "- edges_assignment: [\"B\"|\"M\"|\"V\", ...] — B=boundary, M=mountain fold, V=valley fold\n",
314
+ "- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative, V: positive, B: 0)\n",
315
+ "\n",
316
+ "Rules:\n",
317
+ "- Boundary edges (B) must outline the paper rectangle\n",
318
+ "- At least one fold crease (M or V) must exist\n",
319
+ "- Mountain fold angles are negative (-180 to 0)\n",
320
+ "- Valley fold angles are positive (0 to 180)\n",
321
+ "- All vertex indices in edges must be valid (0 to N-1)\n",
322
+ "\n",
323
+ "Output ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n",
324
+ "\n",
325
+ "\n",
326
+ "def build_prompt(task):\n",
327
+ " return PROMPT_TEMPLATE.format(\n",
328
+ " description=task[\"description\"],\n",
329
+ " width=task[\"paper\"][\"width\"],\n",
330
+ " height=task[\"paper\"][\"height\"],\n",
331
+ " )\n",
332
+ "\n",
333
+ "\n",
334
+ "# Fetch tasks from the OpenEnv environment\n",
335
+ "env_tasks = env.get_tasks()\n",
336
+ "task_names = list(env_tasks.keys())\n",
337
+ "print(f\"Tasks from OpenEnv: {task_names}\")\n",
338
+ "\n",
339
+ "# Build multi-task dataset\n",
340
+ "rows = []\n",
341
+ "for i in range(1000):\n",
342
+ " tn = task_names[i % len(task_names)]\n",
343
+ " t = env_tasks[tn]\n",
344
+ " rows.append({\n",
345
+ " \"prompt\": [{\"role\": \"user\", \"content\": build_prompt(t)}],\n",
346
+ " \"task_name\": tn,\n",
347
+ " \"answer\": 0,\n",
348
+ " })\n",
349
+ "\n",
350
+ "grpo_dataset = Dataset.from_list(rows)\n",
351
+ "print(f\"Dataset: {len(grpo_dataset)} rows across {len(task_names)} tasks\")"
352
  ]
353
  },
354
  {
355
  "cell_type": "markdown",
356
  "metadata": {},
357
  "source": [
358
+ "## 6. Configuration"
359
  ]
360
  },
361
  {
 
389
  "cell_type": "markdown",
390
  "metadata": {},
391
  "source": [
392
+ "## 7. Load Model + LoRA"
393
  ]
394
  },
395
  {
 
418
  "cell_type": "markdown",
419
  "metadata": {},
420
  "source": [
421
+ "## 8. GRPO Training\n",
422
+ "\n",
423
+ "Each training step:\n",
424
+ "1. Pick a task prompt (rotates through all origami tasks returned by `env.get_tasks()`)\n",
425
+ "2. Generate N completions (FOLD JSON attempts)\n",
426
+ "3. Score each: `valid_fold_reward` (local) + `openenv_reward` (calls HF Space API)\n",
427
+ "4. GRPO computes group-relative advantages and updates the policy\n",
428
  "\n",
429
+ "The **OpenEnv environment** on HF Spaces handles the fold simulation and shape comparison."
 
 
 
 
430
  ]
431
  },
432
  {
 
459
  "trainer = GRPOTrainer(\n",
460
  " model=model,\n",
461
  " processing_class=tokenizer,\n",
462
+ " reward_funcs=[valid_fold_reward, openenv_reward],\n",
463
  " args=training_args,\n",
464
  " train_dataset=grpo_dataset,\n",
465
  ")\n",
466
  "\n",
467
+ "print(\"Starting GRPO training (rewards from OpenEnv API)...\")\n",
468
  "trainer.train()\n",
469
  "print(\"Done!\")"
470
  ]
 
473
  "cell_type": "markdown",
474
  "metadata": {},
475
  "source": [
476
+ "## 9. Test Across Tasks"
477
  ]
478
  },
479
  {
 
484
  "source": [
485
  "FastLanguageModel.for_inference(model)\n",
486
  "\n",
487
+ "for tn in list(env_tasks.keys())[:4]:\n",
488
+ " task = env_tasks[tn]\n",
489
  " prompt = build_prompt(task)\n",
490
  " inputs = tokenizer.apply_chat_template(\n",
491
  " [{\"role\": \"user\", \"content\": prompt}],\n",
 
496
  "\n",
497
  " fold = extract_fold_json(response)\n",
498
  " if fold:\n",
499
+ " try:\n",
500
+ " env.reset(task_name=tn)\n",
501
+ " result = env.step(fold)\n",
502
+ " reward = result.get(\"reward\", 0)\n",
503
+ " sim = result[\"observation\"].get(\"shape_similarity\", 0)\n",
504
+ " print(f\"{tn:15s} | reward={reward:6.2f} | similarity={sim:.3f}\")\n",
505
+ " except Exception as e:\n",
506
+ " print(f\"{tn:15s} | env error: {e}\")\n",
507
  " else:\n",
508
  " print(f\"{tn:15s} | no JSON extracted\")"
509
  ]
 
512
  "cell_type": "markdown",
513
  "metadata": {},
514
  "source": [
515
+ "## 10. Save Model"
516
  ]
517
  },
518
  {