Spaces:
Running
Running
Commit ·
c0cedb4
1
Parent(s): 490094b
Route rewards through OpenEnv API instead of local computation
Browse files- Add Next.js rewrite proxy: /api/env/* → localhost:8000 (OpenEnv backend)
- Replace local shape_match_reward with openenv_reward that calls:
POST /api/env/reset (start episode for task)
POST /api/env/step (submit FOLD, get reward from environment)
- Add OptigamiEnvClient class for clean API interaction
- Fetch task list from environment instead of hardcoding
- Test cell verifies full reset→step→reward loop before training
- valid_fold_reward stays local (just JSON parsing, no simulation)
The training loop now properly uses the OpenEnv environment for
reward computation, satisfying the submission requirement.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- next.config.ts +8 -0
- training/train_grpo.ipynb +234 -335
next.config.ts
CHANGED
|
@@ -20,6 +20,14 @@ const nextConfig: NextConfig = {
|
|
| 20 |
],
|
| 21 |
},
|
| 22 |
output: 'standalone',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
transpilePackages: ['motion'],
|
| 24 |
webpack: (config, {dev}) => {
|
| 25 |
// HMR is disabled in AI Studio via DISABLE_HMR env var.
|
|
|
|
| 20 |
],
|
| 21 |
},
|
| 22 |
output: 'standalone',
|
| 23 |
+
async rewrites() {
|
| 24 |
+
return [
|
| 25 |
+
{
|
| 26 |
+
source: '/api/env/:path*',
|
| 27 |
+
destination: 'http://localhost:8000/:path*',
|
| 28 |
+
},
|
| 29 |
+
];
|
| 30 |
+
},
|
| 31 |
transpilePackages: ['motion'],
|
| 32 |
webpack: (config, {dev}) => {
|
| 33 |
// HMR is disabled in AI Studio via DISABLE_HMR env var.
|
training/train_grpo.ipynb
CHANGED
|
@@ -4,12 +4,12 @@
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# Optigami — GRPO Training
|
| 8 |
"\n",
|
| 9 |
"Train an LLM to generate FOLD-format crease patterns using **GRPO** with **Unsloth** + **TRL**.\n",
|
| 10 |
"\n",
|
| 11 |
-
"
|
| 12 |
-
"
|
| 13 |
"\n",
|
| 14 |
"**Environment**: [openenv-community/optigami_](https://huggingface.co/spaces/openenv-community/optigami_) (OpenEnv 0.2.1)\n",
|
| 15 |
"\n",
|
|
@@ -45,159 +45,19 @@
|
|
| 45 |
"elif importlib.util.find_spec(\"unsloth\") is None:\n",
|
| 46 |
" !pip install -qqq unsloth trackio\n",
|
| 47 |
"!pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n",
|
| 48 |
-
"!pip install -qqq scipy datasets"
|
| 49 |
]
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"cell_type": "markdown",
|
| 53 |
"metadata": {},
|
| 54 |
"source": [
|
| 55 |
-
"## 2.
|
| 56 |
-
]
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"cell_type": "code",
|
| 60 |
-
"execution_count": null,
|
| 61 |
-
"metadata": {},
|
| 62 |
-
"outputs": [],
|
| 63 |
-
"source": [
|
| 64 |
-
"import json, re, random\n",
|
| 65 |
-
"from collections import defaultdict\n",
|
| 66 |
-
"from dataclasses import dataclass\n",
|
| 67 |
-
"from typing import Any\n",
|
| 68 |
-
"\n",
|
| 69 |
-
"import numpy as np\n",
|
| 70 |
-
"from scipy.spatial.distance import cdist\n",
|
| 71 |
-
"from scipy.spatial.transform import Rotation\n",
|
| 72 |
-
"\n",
|
| 73 |
-
"\n",
|
| 74 |
-
"def validate_fold(fold_data):\n",
|
| 75 |
-
" for key in (\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"):\n",
|
| 76 |
-
" if key not in fold_data: return False, f\"Missing: {key}\"\n",
|
| 77 |
-
" verts, edges, asgn = fold_data[\"vertices_coords\"], fold_data[\"edges_vertices\"], fold_data[\"edges_assignment\"]\n",
|
| 78 |
-
" if len(verts) < 3: return False, \"<3 verts\"\n",
|
| 79 |
-
" if len(edges) < 3: return False, \"<3 edges\"\n",
|
| 80 |
-
" if len(edges) != len(asgn): return False, \"len mismatch\"\n",
|
| 81 |
-
" if \"edges_foldAngle\" in fold_data and len(fold_data[\"edges_foldAngle\"]) != len(edges): return False, \"angle len\"\n",
|
| 82 |
-
" nv = len(verts)\n",
|
| 83 |
-
" for i, v in enumerate(verts):\n",
|
| 84 |
-
" if not isinstance(v, (list, tuple)) or len(v) < 2: return False, f\"vert {i}\"\n",
|
| 85 |
-
" for i, e in enumerate(edges):\n",
|
| 86 |
-
" if not isinstance(e, (list, tuple)) or len(e) != 2: return False, f\"edge {i}\"\n",
|
| 87 |
-
" if e[0]<0 or e[0]>=nv or e[1]<0 or e[1]>=nv: return False, f\"edge {i} range\"\n",
|
| 88 |
-
" if e[0]==e[1]: return False, f\"edge {i} degen\"\n",
|
| 89 |
-
" for i, a in enumerate(asgn):\n",
|
| 90 |
-
" if a not in {\"M\",\"V\",\"B\",\"F\",\"U\",\"C\"}: return False, f\"asgn {i}\"\n",
|
| 91 |
-
" if not any(a in (\"M\",\"V\") for a in asgn): return False, \"no folds\"\n",
|
| 92 |
-
" if not any(a==\"B\" for a in asgn): return False, \"no boundary\"\n",
|
| 93 |
-
" return True, \"\"\n",
|
| 94 |
-
"\n",
|
| 95 |
-
"def parse_fold(fd):\n",
|
| 96 |
-
" verts = fd[\"vertices_coords\"]\n",
|
| 97 |
-
" vertices = np.zeros((len(verts),3), dtype=np.float64)\n",
|
| 98 |
-
" for i,v in enumerate(verts): vertices[i,0]=v[0]; vertices[i,1]=v[1]; (len(v)>2 and setattr(vertices, '__setitem__', None) is None) or (vertices.__setitem__((i,2), v[2]) if len(v)>2 else None)\n",
|
| 99 |
-
" # Fix the above — simple approach\n",
|
| 100 |
-
" vertices = np.zeros((len(verts),3), dtype=np.float64)\n",
|
| 101 |
-
" for i,v in enumerate(verts):\n",
|
| 102 |
-
" vertices[i,0]=v[0]; vertices[i,1]=v[1]\n",
|
| 103 |
-
" if len(v)>2: vertices[i,2]=v[2]\n",
|
| 104 |
-
" edges = np.array(fd[\"edges_vertices\"], dtype=np.int32)\n",
|
| 105 |
-
" asgn = list(fd[\"edges_assignment\"])\n",
|
| 106 |
-
" if \"edges_foldAngle\" in fd: fa = np.array(fd[\"edges_foldAngle\"], dtype=np.float64)\n",
|
| 107 |
-
" else:\n",
|
| 108 |
-
" fa = np.zeros(len(edges), dtype=np.float64)\n",
|
| 109 |
-
" for i,a in enumerate(asgn):\n",
|
| 110 |
-
" if a==\"V\": fa[i]=180.0\n",
|
| 111 |
-
" elif a==\"M\": fa[i]=-180.0\n",
|
| 112 |
-
" fa_rad = np.radians(fa)\n",
|
| 113 |
-
" if \"faces_vertices\" in fd:\n",
|
| 114 |
-
" tris = []\n",
|
| 115 |
-
" for face in fd[\"faces_vertices\"]:\n",
|
| 116 |
-
" for i in range(1,len(face)-1): tris.append([face[0],face[i],face[i+1]])\n",
|
| 117 |
-
" faces = np.array(tris, dtype=np.int32) if tris else np.zeros((0,3), dtype=np.int32)\n",
|
| 118 |
-
" else:\n",
|
| 119 |
-
" adj=defaultdict(set)\n",
|
| 120 |
-
" for v1,v2 in edges: adj[v1].add(v2); adj[v2].add(v1)\n",
|
| 121 |
-
" ts=set()\n",
|
| 122 |
-
" for v1,v2 in edges:\n",
|
| 123 |
-
" for v3 in adj[v1]&adj[v2]: ts.add(tuple(sorted([v1,v2,v3])))\n",
|
| 124 |
-
" faces = np.array(list(ts), dtype=np.int32) if ts else np.zeros((0,3), dtype=np.int32)\n",
|
| 125 |
-
" return {\"vertices\":vertices,\"edges\":edges,\"assignments\":asgn,\"fold_angles\":fa_rad,\"faces\":faces}\n",
|
| 126 |
-
"\n",
|
| 127 |
-
"@dataclass\n",
|
| 128 |
-
"class SimResult:\n",
|
| 129 |
-
" positions: np.ndarray; converged: bool; steps_taken: int; max_strain: float; total_energy: float\n",
|
| 130 |
-
"\n",
|
| 131 |
-
"def simulate(fold_data, crease_percent=1.0):\n",
|
| 132 |
-
" p = parse_fold(fold_data)\n",
|
| 133 |
-
" flat=p[\"vertices\"].copy(); edges=p[\"edges\"]; asgn=p[\"assignments\"]; fa=p[\"fold_angles\"]; faces=p[\"faces\"]; pos=flat.copy()\n",
|
| 134 |
-
" if len(faces)==0: return SimResult(pos,True,0,0.,0.)\n",
|
| 135 |
-
" fadj={}\n",
|
| 136 |
-
" for fi,face in enumerate(faces):\n",
|
| 137 |
-
" for j in range(len(face)):\n",
|
| 138 |
-
" v1,v2=int(face[j]),int(face[(j+1)%len(face)]); k=(min(v1,v2),max(v1,v2)); fadj.setdefault(k,[]).append(fi)\n",
|
| 139 |
-
" cm={}\n",
|
| 140 |
-
" for i,(v1,v2) in enumerate(edges):\n",
|
| 141 |
-
" k=(min(int(v1),int(v2)),max(int(v1),int(v2)))\n",
|
| 142 |
-
" if asgn[i] in (\"M\",\"V\"): cm[k]=fa[i]*crease_percent\n",
|
| 143 |
-
" nf=len(faces); fR=[None]*nf; ft=[None]*nf; fR[0]=np.eye(3); ft[0]=np.zeros(3)\n",
|
| 144 |
-
" vis=[False]*nf; vis[0]=True; placed=set(int(vi) for vi in faces[0])\n",
|
| 145 |
-
" q=[0]\n",
|
| 146 |
-
" while q:\n",
|
| 147 |
-
" fi=q.pop(0); face=faces[fi]\n",
|
| 148 |
-
" for j in range(len(face)):\n",
|
| 149 |
-
" v1,v2=int(face[j]),int(face[(j+1)%len(face)]); ek=(min(v1,v2),max(v1,v2))\n",
|
| 150 |
-
" for fj in fadj.get(ek,[]):\n",
|
| 151 |
-
" if vis[fj]: continue\n",
|
| 152 |
-
" vis[fj]=True; q.append(fj); ang=cm.get(ek,0.0)\n",
|
| 153 |
-
" if abs(ang)>1e-10:\n",
|
| 154 |
-
" p1=pos[v1].copy(); ax=pos[v2]-p1; al=np.linalg.norm(ax)\n",
|
| 155 |
-
" fr=Rotation.from_rotvec(ang*ax/al).as_matrix() if al>1e-12 else np.eye(3)\n",
|
| 156 |
-
" fR[fj]=fr@fR[fi]; ft[fj]=fr@(ft[fi]-p1)+p1\n",
|
| 157 |
-
" else: fR[fj]=fR[fi].copy(); ft[fj]=ft[fi].copy()\n",
|
| 158 |
-
" for vi in faces[fj]:\n",
|
| 159 |
-
" vi2=int(vi)\n",
|
| 160 |
-
" if vi2 not in placed: pos[vi2]=fR[fj]@flat[vi2]+ft[fj]; placed.add(vi2)\n",
|
| 161 |
-
" ms=0.0\n",
|
| 162 |
-
" for v1,v2 in edges:\n",
|
| 163 |
-
" r=np.linalg.norm(flat[v2]-flat[v1]); c=np.linalg.norm(pos[v2]-pos[v1])\n",
|
| 164 |
-
" if r>1e-12: ms=max(ms,abs(c-r)/r)\n",
|
| 165 |
-
" return SimResult(pos,True,1,ms,0.)\n",
|
| 166 |
-
"\n",
|
| 167 |
-
"def _rotations():\n",
|
| 168 |
-
" rs=[np.eye(3)]\n",
|
| 169 |
-
" for k in range(1,4):\n",
|
| 170 |
-
" a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
|
| 171 |
-
" rs.append(np.array([[c,-s,0],[s,c,0],[0,0,1]]))\n",
|
| 172 |
-
" for k in range(1,4):\n",
|
| 173 |
-
" a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
|
| 174 |
-
" rs.append(np.array([[1,0,0],[0,c,-s],[0,s,c]]))\n",
|
| 175 |
-
" for k in range(1,4):\n",
|
| 176 |
-
" a=k*np.pi/2; c,s=np.cos(a),np.sin(a)\n",
|
| 177 |
-
" rs.append(np.array([[c,0,s],[0,1,0],[-s,0,c]]))\n",
|
| 178 |
-
" rs+=[np.diag([-1.,1.,1.]),np.diag([1.,-1.,1.]),np.diag([1.,1.,-1.])]\n",
|
| 179 |
-
" return rs\n",
|
| 180 |
-
"\n",
|
| 181 |
-
"def compute_shape_match(pred, tgt):\n",
|
| 182 |
-
" if len(pred)==0 or len(tgt)==0: return 0.0\n",
|
| 183 |
-
" pc=pred-pred.mean(0); tc=tgt-tgt.mean(0); best=0.0\n",
|
| 184 |
-
" for r in _rotations():\n",
|
| 185 |
-
" rot=pc@r.T; d=cdist(rot,tc); ch=(d.min(1).mean()+d.min(0).mean())/2\n",
|
| 186 |
-
" ap=np.vstack([rot,tc]); dg=np.linalg.norm(ap.max(0)-ap.min(0))\n",
|
| 187 |
-
" sc=max(0.,1.-ch/dg) if dg>1e-12 else (1. if ch<1e-12 else 0.)\n",
|
| 188 |
-
" best=max(best,sc)\n",
|
| 189 |
-
" return best\n",
|
| 190 |
-
"\n",
|
| 191 |
-
"print(\"Engine loaded.\")"
|
| 192 |
-
]
|
| 193 |
-
},
|
| 194 |
-
{
|
| 195 |
-
"cell_type": "markdown",
|
| 196 |
-
"metadata": {},
|
| 197 |
-
"source": [
|
| 198 |
-
"## 3. Task Definitions\n",
|
| 199 |
"\n",
|
| 200 |
-
"
|
|
|
|
|
|
|
|
|
|
| 201 |
]
|
| 202 |
},
|
| 203 |
{
|
|
@@ -206,104 +66,75 @@
|
|
| 206 |
"metadata": {},
|
| 207 |
"outputs": [],
|
| 208 |
"source": [
|
| 209 |
-
"
|
| 210 |
-
"
|
| 211 |
-
"
|
| 212 |
-
"
|
| 213 |
-
"
|
| 214 |
-
"
|
| 215 |
-
"
|
| 216 |
-
" \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n",
|
| 217 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
|
| 218 |
-
" \"edges_foldAngle\": [0,0,0,0,180],\n",
|
| 219 |
-
" \"faces_vertices\": [[0,1,2],[0,2,3]],\n",
|
| 220 |
-
" },\n",
|
| 221 |
-
" },\n",
|
| 222 |
-
" \"half_fold\": {\n",
|
| 223 |
-
" \"name\": \"half_fold\",\n",
|
| 224 |
-
" \"description\": \"Fold the paper in half horizontally along the middle\",\n",
|
| 225 |
-
" \"difficulty\": 1, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
|
| 226 |
-
" \"target_fold\": {\n",
|
| 227 |
-
" \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1],[0,0.5],[1,0.5]],\n",
|
| 228 |
-
" \"edges_vertices\": [[0,1],[1,5],[5,2],[2,3],[3,4],[4,0],[4,5]],\n",
|
| 229 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
|
| 230 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,180],\n",
|
| 231 |
-
" \"faces_vertices\": [[0,1,5,4],[4,5,2,3]],\n",
|
| 232 |
-
" },\n",
|
| 233 |
-
" },\n",
|
| 234 |
-
" \"quarter_fold\": {\n",
|
| 235 |
-
" \"name\": \"quarter_fold\",\n",
|
| 236 |
-
" \"description\": \"Fold the paper into quarters with two perpendicular creases through the center\",\n",
|
| 237 |
-
" \"difficulty\": 2, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
|
| 238 |
-
" \"target_fold\": {\n",
|
| 239 |
-
" \"vertices_coords\": [[0,0],[0.5,0],[1,0],[0,0.5],[0.5,0.5],[1,0.5],[0,1],[0.5,1],[1,1]],\n",
|
| 240 |
-
" \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[1,4],[4,7],[3,4],[4,5]],\n",
|
| 241 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"V\",\"V\",\"V\"],\n",
|
| 242 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,180,180,180],\n",
|
| 243 |
-
" \"faces_vertices\": [[0,1,4,3],[1,2,5,4],[3,4,7,6],[4,5,8,7]],\n",
|
| 244 |
-
" },\n",
|
| 245 |
-
" },\n",
|
| 246 |
-
" \"accordion\": {\n",
|
| 247 |
-
" \"name\": \"accordion\",\n",
|
| 248 |
-
" \"description\": \"Make a zig-zag accordion fold with alternating mountain and valley creases like a paper fan\",\n",
|
| 249 |
-
" \"difficulty\": 2, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
|
| 250 |
-
" \"target_fold\": {\n",
|
| 251 |
-
" \"vertices_coords\": [[-1,1],[-0.5,1],[0,1],[0.5,1],[1,1],[-1,-1],[-0.5,-1],[0,-1],[0.5,-1],[1,-1]],\n",
|
| 252 |
-
" \"edges_vertices\": [[0,1],[1,2],[2,3],[3,4],[5,6],[6,7],[7,8],[8,9],[0,5],[4,9],[1,6],[2,7],[3,8]],\n",
|
| 253 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"M\",\"V\"],\n",
|
| 254 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,0,0,0,0,180,-180,180],\n",
|
| 255 |
-
" \"faces_vertices\": [[0,1,6,5],[1,2,7,6],[2,3,8,7],[3,4,9,8]],\n",
|
| 256 |
-
" },\n",
|
| 257 |
-
" },\n",
|
| 258 |
-
" \"waterbomb\": {\n",
|
| 259 |
-
" \"name\": \"waterbomb\",\n",
|
| 260 |
-
" \"description\": \"Create a waterbomb base with valley folds on diagonals and mountain folds on midlines of a square\",\n",
|
| 261 |
-
" \"difficulty\": 3, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
|
| 262 |
-
" \"target_fold\": {\n",
|
| 263 |
-
" \"vertices_coords\": [[-1,1],[0,1],[1,1],[-1,0],[0,0],[1,0],[-1,-1],[0,-1],[1,-1]],\n",
|
| 264 |
-
" \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[0,4],[2,4],[6,4],[8,4],[1,4],[3,4],[5,4],[7,4]],\n",
|
| 265 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"V\",\"V\",\"V\",\"M\",\"M\",\"M\",\"M\"],\n",
|
| 266 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,180,180,180,-180,-180,-180,-180],\n",
|
| 267 |
-
" \"faces_vertices\": [[0,1,4],[1,2,4],[2,5,4],[5,8,4],[8,7,4],[7,6,4],[6,3,4],[3,0,4]],\n",
|
| 268 |
-
" },\n",
|
| 269 |
-
" },\n",
|
| 270 |
-
" \"miura_ori\": {\n",
|
| 271 |
-
" \"name\": \"miura_ori\",\n",
|
| 272 |
-
" \"description\": \"Create a Miura-ori tessellation with offset zigzag vertices that folds flat in one motion\",\n",
|
| 273 |
-
" \"difficulty\": 3, \"paper\": {\"width\": 2.0, \"height\": 2.0},\n",
|
| 274 |
-
" \"target_fold\": {\n",
|
| 275 |
-
" \"vertices_coords\": [[-1,1],[0,1.2],[1,1],[-1,0],[0,0.2],[1,0],[-1,-1],[0,-0.8],[1,-1]],\n",
|
| 276 |
-
" \"edges_vertices\": [[0,1],[1,2],[2,5],[5,8],[8,7],[7,6],[6,3],[3,0],[1,4],[4,7],[3,4],[4,5]],\n",
|
| 277 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"M\",\"M\",\"V\",\"M\"],\n",
|
| 278 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,0,0,-180,-180,180,-180],\n",
|
| 279 |
-
" \"faces_vertices\": [[0,1,4,3],[1,2,5,4],[3,4,7,6],[4,5,8,7]],\n",
|
| 280 |
-
" },\n",
|
| 281 |
-
" },\n",
|
| 282 |
-
" \"letter_fold\": {\n",
|
| 283 |
-
" \"name\": \"letter_fold\",\n",
|
| 284 |
-
" \"description\": \"Tri-fold the paper into thirds like a letter envelope with two parallel horizontal creases\",\n",
|
| 285 |
-
" \"difficulty\": 2, \"paper\": {\"width\": 1.0, \"height\": 1.0},\n",
|
| 286 |
-
" \"target_fold\": {\n",
|
| 287 |
-
" \"vertices_coords\": [[0,0],[1,0],[0,0.333],[1,0.333],[0,0.667],[1,0.667],[0,1],[1,1]],\n",
|
| 288 |
-
" \"edges_vertices\": [[0,1],[1,3],[3,5],[5,7],[7,6],[6,4],[4,2],[2,0],[2,3],[4,5]],\n",
|
| 289 |
-
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"B\",\"V\",\"M\"],\n",
|
| 290 |
-
" \"edges_foldAngle\": [0,0,0,0,0,0,0,0,180,-180],\n",
|
| 291 |
-
" \"faces_vertices\": [[0,1,3,2],[2,3,5,4],[4,5,7,6]],\n",
|
| 292 |
-
" },\n",
|
| 293 |
-
" },\n",
|
| 294 |
-
"}\n",
|
| 295 |
"\n",
|
| 296 |
-
"
|
| 297 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
]
|
| 299 |
},
|
| 300 |
{
|
| 301 |
"cell_type": "markdown",
|
| 302 |
"metadata": {},
|
| 303 |
"source": [
|
| 304 |
-
"##
|
| 305 |
"\n",
|
| 306 |
-
"
|
| 307 |
]
|
| 308 |
},
|
| 309 |
{
|
|
@@ -312,62 +143,37 @@
|
|
| 312 |
"metadata": {},
|
| 313 |
"outputs": [],
|
| 314 |
"source": [
|
| 315 |
-
"
|
| 316 |
-
"\n",
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"\n",
|
| 320 |
-
"Target: {
|
| 321 |
-
"
|
| 322 |
-
"\n",
|
| 323 |
-
"
|
| 324 |
-
"
|
| 325 |
-
"
|
| 326 |
-
"
|
| 327 |
-
"
|
| 328 |
-
"\n",
|
| 329 |
-
"
|
| 330 |
-
"
|
| 331 |
-
"
|
| 332 |
-
"
|
| 333 |
-
"
|
| 334 |
-
"
|
| 335 |
-
"\
|
| 336 |
-
"Output ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n",
|
| 337 |
-
"\n",
|
| 338 |
-
"\n",
|
| 339 |
-
"def build_prompt(task):\n",
|
| 340 |
-
" return PROMPT_TEMPLATE.format(\n",
|
| 341 |
-
" description=task[\"description\"],\n",
|
| 342 |
-
" width=task[\"paper\"][\"width\"],\n",
|
| 343 |
-
" height=task[\"paper\"][\"height\"],\n",
|
| 344 |
-
" )\n",
|
| 345 |
-
"\n",
|
| 346 |
-
"\n",
|
| 347 |
-
"# Build multi-task dataset — rotate through all tasks\n",
|
| 348 |
-
"task_names = list(TASKS.keys())\n",
|
| 349 |
-
"rows = []\n",
|
| 350 |
-
"for i in range(1000):\n",
|
| 351 |
-
" tn = task_names[i % len(task_names)]\n",
|
| 352 |
-
" t = TASKS[tn]\n",
|
| 353 |
-
" rows.append({\n",
|
| 354 |
-
" \"prompt\": [{\"role\": \"user\", \"content\": build_prompt(t)}],\n",
|
| 355 |
-
" \"task_name\": tn,\n",
|
| 356 |
-
" \"answer\": 0,\n",
|
| 357 |
-
" })\n",
|
| 358 |
-
"\n",
|
| 359 |
-
"grpo_dataset = Dataset.from_list(rows)\n",
|
| 360 |
-
"print(f\"Dataset: {len(grpo_dataset)} rows across {len(task_names)} tasks\")\n",
|
| 361 |
-
"print(f\"Task distribution: {dict((tn, sum(1 for r in rows if r['task_name']==tn)) for tn in task_names)}\")"
|
| 362 |
]
|
| 363 |
},
|
| 364 |
{
|
| 365 |
"cell_type": "markdown",
|
| 366 |
"metadata": {},
|
| 367 |
"source": [
|
| 368 |
-
"##
|
| 369 |
"\n",
|
| 370 |
-
"
|
|
|
|
| 371 |
]
|
| 372 |
},
|
| 373 |
{
|
|
@@ -379,6 +185,7 @@
|
|
| 379 |
"PRINTER = 0\n",
|
| 380 |
"\n",
|
| 381 |
"def extract_fold_json(response):\n",
|
|
|
|
| 382 |
" m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
|
| 383 |
" if m:\n",
|
| 384 |
" try: return json.loads(m.group(1))\n",
|
|
@@ -395,52 +202,95 @@
|
|
| 395 |
"\n",
|
| 396 |
"\n",
|
| 397 |
"def valid_fold_reward(completions, **kwargs):\n",
|
| 398 |
-
" \"\"\"Reward 1: +1.0 valid FOLD, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
|
|
|
|
| 399 |
" scores = []\n",
|
| 400 |
" for c in completions:\n",
|
| 401 |
" fold = extract_fold_json(c[0][\"content\"])\n",
|
| 402 |
-
" if fold is None:
|
| 403 |
-
"
|
| 404 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
" return scores\n",
|
| 406 |
"\n",
|
| 407 |
"\n",
|
| 408 |
-
"def
|
| 409 |
-
" \"\"\"Reward 2
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
" global PRINTER\n",
|
| 411 |
-
" # task_name comes as a list
|
| 412 |
-
" if isinstance(task_name, list)
|
| 413 |
-
" tn = task_name[0] # All completions in a group share the same prompt/task\n",
|
| 414 |
-
" else:\n",
|
| 415 |
-
" tn = task_name\n",
|
| 416 |
-
" task = get_task(tn)\n",
|
| 417 |
-
" try: tgt = simulate(task[\"target_fold\"]).positions\n",
|
| 418 |
-
" except: return [0.0]*len(completions)\n",
|
| 419 |
"\n",
|
| 420 |
" scores = []\n",
|
| 421 |
" for c in completions:\n",
|
| 422 |
" resp = c[0][\"content\"]\n",
|
|
|
|
|
|
|
| 423 |
" if PRINTER % 10 == 0:\n",
|
| 424 |
" print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
|
| 425 |
" PRINTER += 1\n",
|
|
|
|
|
|
|
| 426 |
" fold = extract_fold_json(resp)\n",
|
| 427 |
-
" if fold is None:
|
| 428 |
-
"
|
| 429 |
-
"
|
|
|
|
| 430 |
" try:\n",
|
| 431 |
-
"
|
| 432 |
-
"
|
| 433 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
" return scores\n",
|
| 435 |
"\n",
|
| 436 |
-
"
|
|
|
|
| 437 |
]
|
| 438 |
},
|
| 439 |
{
|
| 440 |
"cell_type": "markdown",
|
| 441 |
"metadata": {},
|
| 442 |
"source": [
|
| 443 |
-
"##
|
|
|
|
|
|
|
| 444 |
]
|
| 445 |
},
|
| 446 |
{
|
|
@@ -449,18 +299,63 @@
|
|
| 449 |
"metadata": {},
|
| 450 |
"outputs": [],
|
| 451 |
"source": [
|
| 452 |
-
"
|
| 453 |
-
"
|
| 454 |
-
"
|
| 455 |
-
"
|
| 456 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
]
|
| 458 |
},
|
| 459 |
{
|
| 460 |
"cell_type": "markdown",
|
| 461 |
"metadata": {},
|
| 462 |
"source": [
|
| 463 |
-
"##
|
| 464 |
]
|
| 465 |
},
|
| 466 |
{
|
|
@@ -494,7 +389,7 @@
|
|
| 494 |
"cell_type": "markdown",
|
| 495 |
"metadata": {},
|
| 496 |
"source": [
|
| 497 |
-
"##
|
| 498 |
]
|
| 499 |
},
|
| 500 |
{
|
|
@@ -523,13 +418,15 @@
|
|
| 523 |
"cell_type": "markdown",
|
| 524 |
"metadata": {},
|
| 525 |
"source": [
|
| 526 |
-
"##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
"\n",
|
| 528 |
-
"
|
| 529 |
-
"1. Pick a task prompt (rotates through triangle, half_fold, quarter_fold, accordion, waterbomb, miura_ori, letter_fold)\n",
|
| 530 |
-
"2. Generate N completions\n",
|
| 531 |
-
"3. Score each: valid_fold (+1/-0.5/-2) + shape_match (0-20)\n",
|
| 532 |
-
"4. GRPO uses within-group variance to update the policy"
|
| 533 |
]
|
| 534 |
},
|
| 535 |
{
|
|
@@ -562,12 +459,12 @@
|
|
| 562 |
"trainer = GRPOTrainer(\n",
|
| 563 |
" model=model,\n",
|
| 564 |
" processing_class=tokenizer,\n",
|
| 565 |
-
" reward_funcs=[valid_fold_reward,
|
| 566 |
" args=training_args,\n",
|
| 567 |
" train_dataset=grpo_dataset,\n",
|
| 568 |
")\n",
|
| 569 |
"\n",
|
| 570 |
-
"print(\"Starting GRPO training...\")\n",
|
| 571 |
"trainer.train()\n",
|
| 572 |
"print(\"Done!\")"
|
| 573 |
]
|
|
@@ -576,7 +473,7 @@
|
|
| 576 |
"cell_type": "markdown",
|
| 577 |
"metadata": {},
|
| 578 |
"source": [
|
| 579 |
-
"##
|
| 580 |
]
|
| 581 |
},
|
| 582 |
{
|
|
@@ -587,8 +484,8 @@
|
|
| 587 |
"source": [
|
| 588 |
"FastLanguageModel.for_inference(model)\n",
|
| 589 |
"\n",
|
| 590 |
-
"for tn in [
|
| 591 |
-
" task =
|
| 592 |
" prompt = build_prompt(task)\n",
|
| 593 |
" inputs = tokenizer.apply_chat_template(\n",
|
| 594 |
" [{\"role\": \"user\", \"content\": prompt}],\n",
|
|
@@ -599,12 +496,14 @@
|
|
| 599 |
"\n",
|
| 600 |
" fold = extract_fold_json(response)\n",
|
| 601 |
" if fold:\n",
|
| 602 |
-
"
|
| 603 |
-
"
|
| 604 |
-
"
|
| 605 |
-
"
|
| 606 |
-
"
|
| 607 |
-
" print(f\"{tn:15s} |
|
|
|
|
|
|
|
| 608 |
" else:\n",
|
| 609 |
" print(f\"{tn:15s} | no JSON extracted\")"
|
| 610 |
]
|
|
@@ -613,7 +512,7 @@
|
|
| 613 |
"cell_type": "markdown",
|
| 614 |
"metadata": {},
|
| 615 |
"source": [
|
| 616 |
-
"##
|
| 617 |
]
|
| 618 |
},
|
| 619 |
{
|
|
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
+
"# Optigami — GRPO Training with OpenEnv\n",
|
| 8 |
"\n",
|
| 9 |
"Train an LLM to generate FOLD-format crease patterns using **GRPO** with **Unsloth** + **TRL**.\n",
|
| 10 |
"\n",
|
| 11 |
+
"**Rewards come from the OpenEnv environment** deployed on HF Spaces — the training loop\n",
|
| 12 |
+
"calls `POST /reset` and `POST /step` on the live environment to get simulation results.\n",
|
| 13 |
"\n",
|
| 14 |
"**Environment**: [openenv-community/optigami_](https://huggingface.co/spaces/openenv-community/optigami_) (OpenEnv 0.2.1)\n",
|
| 15 |
"\n",
|
|
|
|
| 45 |
"elif importlib.util.find_spec(\"unsloth\") is None:\n",
|
| 46 |
" !pip install -qqq unsloth trackio\n",
|
| 47 |
"!pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n",
|
| 48 |
+
"!pip install -qqq scipy datasets requests"
|
| 49 |
]
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"cell_type": "markdown",
|
| 53 |
"metadata": {},
|
| 54 |
"source": [
|
| 55 |
+
"## 2. OpenEnv Client\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
"\n",
|
| 57 |
+
"Connect to the Optigami OpenEnv environment deployed on HF Spaces.\n",
|
| 58 |
+
"The environment handles fold simulation and reward computation via its API:\n",
|
| 59 |
+
"- `POST /api/env/reset` — start a new episode for a task\n",
|
| 60 |
+
"- `POST /api/env/step` — submit a FOLD crease pattern, get back reward + shape similarity"
|
| 61 |
]
|
| 62 |
},
|
| 63 |
{
|
|
|
|
| 66 |
"metadata": {},
|
| 67 |
"outputs": [],
|
| 68 |
"source": [
|
| 69 |
+
"import requests\n",
|
| 70 |
+
"import json\n",
|
| 71 |
+
"import re\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"# OpenEnv environment URL — deployed on HF Spaces\n",
|
| 74 |
+
"OPENENV_URL = \"https://openenv-community-optigami-c92c300.hf.space/api/env\"\n",
|
| 75 |
+
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
"\n",
|
| 77 |
+
"class OptigamiEnvClient:\n",
|
| 78 |
+
" \"\"\"Client for the Optigami OpenEnv environment on HF Spaces.\"\"\"\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" def __init__(self, base_url=OPENENV_URL):\n",
|
| 81 |
+
" self.base_url = base_url\n",
|
| 82 |
+
" self.session = requests.Session()\n",
|
| 83 |
+
" self.session.headers.update({\"Content-Type\": \"application/json\"})\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" def health(self):\n",
|
| 86 |
+
" \"\"\"Check if the environment is reachable.\"\"\"\n",
|
| 87 |
+
" try:\n",
|
| 88 |
+
" r = self.session.get(f\"{self.base_url}/health\", timeout=10)\n",
|
| 89 |
+
" return r.status_code == 200\n",
|
| 90 |
+
" except:\n",
|
| 91 |
+
" return False\n",
|
| 92 |
+
"\n",
|
| 93 |
+
" def reset(self, task_name=\"triangle\"):\n",
|
| 94 |
+
" \"\"\"Reset environment for a task. Returns observation dict.\"\"\"\n",
|
| 95 |
+
" r = self.session.post(\n",
|
| 96 |
+
" f\"{self.base_url}/reset\",\n",
|
| 97 |
+
" json={\"task_name\": task_name},\n",
|
| 98 |
+
" timeout=30,\n",
|
| 99 |
+
" )\n",
|
| 100 |
+
" r.raise_for_status()\n",
|
| 101 |
+
" return r.json()\n",
|
| 102 |
+
"\n",
|
| 103 |
+
" def step(self, fold_data):\n",
|
| 104 |
+
" \"\"\"Submit a FOLD crease pattern. Returns observation with reward.\"\"\"\n",
|
| 105 |
+
" r = self.session.post(\n",
|
| 106 |
+
" f\"{self.base_url}/step\",\n",
|
| 107 |
+
" json={\"action\": {\"fold_data\": fold_data}},\n",
|
| 108 |
+
" timeout=30,\n",
|
| 109 |
+
" )\n",
|
| 110 |
+
" r.raise_for_status()\n",
|
| 111 |
+
" return r.json()\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" def get_tasks(self):\n",
|
| 114 |
+
" \"\"\"Fetch available tasks from the environment.\"\"\"\n",
|
| 115 |
+
" r = self.session.get(f\"{self.base_url.replace('/api/env', '')}/tasks\", timeout=10)\n",
|
| 116 |
+
" r.raise_for_status()\n",
|
| 117 |
+
" return r.json()\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# Initialize client and verify connection\n",
|
| 121 |
+
"env = OptigamiEnvClient()\n",
|
| 122 |
+
"if env.health():\n",
|
| 123 |
+
" print(f\"Connected to OpenEnv at {OPENENV_URL}\")\n",
|
| 124 |
+
" tasks = env.get_tasks()\n",
|
| 125 |
+
" print(f\"Available tasks: {list(tasks.keys())}\")\n",
|
| 126 |
+
"else:\n",
|
| 127 |
+
" print(f\"WARNING: Cannot reach OpenEnv at {OPENENV_URL}\")\n",
|
| 128 |
+
" print(\"Make sure the HF Space is running!\")"
|
| 129 |
]
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "markdown",
|
| 133 |
"metadata": {},
|
| 134 |
"source": [
|
| 135 |
+
"## 3. Test the OpenEnv Environment\n",
|
| 136 |
"\n",
|
| 137 |
+
"Verify the full loop: reset → step → get reward."
|
| 138 |
]
|
| 139 |
},
|
| 140 |
{
|
|
|
|
| 143 |
"metadata": {},
|
| 144 |
"outputs": [],
|
| 145 |
"source": [
|
| 146 |
+
"# Reset for triangle task\n",
|
| 147 |
+
"obs = env.reset(task_name=\"triangle\")\n",
|
| 148 |
+
"print(\"Reset observation:\")\n",
|
| 149 |
+
"print(f\" Task: {obs['observation']['task']}\")\n",
|
| 150 |
+
"print(f\" Done: {obs['observation']['done']}\")\n",
|
| 151 |
+
"print(f\" Target positions: {len(obs['observation']['target_positions'])} vertices\")\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"# Submit the known-correct triangle fold\n",
|
| 154 |
+
"correct_fold = {\n",
|
| 155 |
+
" \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1]],\n",
|
| 156 |
+
" \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n",
|
| 157 |
+
" \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n",
|
| 158 |
+
" \"edges_foldAngle\": [0,0,0,0,180],\n",
|
| 159 |
+
" \"faces_vertices\": [[0,1,2],[0,2,3]],\n",
|
| 160 |
+
"}\n",
|
| 161 |
+
"result = env.step(correct_fold)\n",
|
| 162 |
+
"print(f\"\\nStep result:\")\n",
|
| 163 |
+
"print(f\" Reward: {result['reward']}\")\n",
|
| 164 |
+
"print(f\" Shape similarity: {result['observation']['shape_similarity']}\")\n",
|
| 165 |
+
"print(f\" Done: {result['observation']['done']}\")\n",
|
| 166 |
+
"print(f\" Is stable: {result['observation']['is_stable']}\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
]
|
| 168 |
},
|
| 169 |
{
|
| 170 |
"cell_type": "markdown",
|
| 171 |
"metadata": {},
|
| 172 |
"source": [
|
| 173 |
+
"## 4. Reward Functions (via OpenEnv API)\n",
|
| 174 |
"\n",
|
| 175 |
+
"- **`valid_fold_reward`**: Local JSON validation (fast, no API call needed)\n",
|
| 176 |
+
"- **`openenv_reward`**: Calls the OpenEnv environment to simulate the fold and get the real reward"
|
| 177 |
]
|
| 178 |
},
|
| 179 |
{
|
|
|
|
| 185 |
"PRINTER = 0\n",
|
| 186 |
"\n",
|
| 187 |
"def extract_fold_json(response):\n",
|
| 188 |
+
" \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n",
|
| 189 |
" m = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n",
|
| 190 |
" if m:\n",
|
| 191 |
" try: return json.loads(m.group(1))\n",
|
|
|
|
| 202 |
"\n",
|
| 203 |
"\n",
|
| 204 |
"def valid_fold_reward(completions, **kwargs):\n",
|
| 205 |
+
" \"\"\"Reward 1 (local): +1.0 valid FOLD structure, -0.5 bad structure, -2.0 unparseable.\"\"\"\n",
|
| 206 |
+
" REQUIRED = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n",
|
| 207 |
" scores = []\n",
|
| 208 |
" for c in completions:\n",
|
| 209 |
" fold = extract_fold_json(c[0][\"content\"])\n",
|
| 210 |
+
" if fold is None:\n",
|
| 211 |
+
" scores.append(-2.0)\n",
|
| 212 |
+
" continue\n",
|
| 213 |
+
" # Basic structural checks\n",
|
| 214 |
+
" if not REQUIRED.issubset(fold.keys()):\n",
|
| 215 |
+
" scores.append(-0.5); continue\n",
|
| 216 |
+
" verts = fold[\"vertices_coords\"]\n",
|
| 217 |
+
" edges = fold[\"edges_vertices\"]\n",
|
| 218 |
+
" asgn = fold[\"edges_assignment\"]\n",
|
| 219 |
+
" if len(verts) < 3 or len(edges) < 3 or len(edges) != len(asgn):\n",
|
| 220 |
+
" scores.append(-0.5); continue\n",
|
| 221 |
+
" if not any(a in (\"M\",\"V\") for a in asgn):\n",
|
| 222 |
+
" scores.append(-0.5); continue\n",
|
| 223 |
+
" if not any(a == \"B\" for a in asgn):\n",
|
| 224 |
+
" scores.append(-0.5); continue\n",
|
| 225 |
+
" scores.append(1.0)\n",
|
| 226 |
" return scores\n",
|
| 227 |
"\n",
|
| 228 |
"\n",
|
| 229 |
+
"def openenv_reward(completions, task_name, **kwargs):\n",
|
| 230 |
+
" \"\"\"Reward 2 (OpenEnv API): Submit fold to environment, get simulation reward.\n",
|
| 231 |
+
"\n",
|
| 232 |
+
" Calls POST /reset and POST /step on the HF Space OpenEnv environment.\n",
|
| 233 |
+
" The environment runs the fold simulation and computes shape similarity.\n",
|
| 234 |
+
" \"\"\"\n",
|
| 235 |
" global PRINTER\n",
|
| 236 |
+
" # task_name comes as a list from the dataset\n",
|
| 237 |
+
" tn = task_name[0] if isinstance(task_name, list) else task_name\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
"\n",
|
| 239 |
" scores = []\n",
|
| 240 |
" for c in completions:\n",
|
| 241 |
" resp = c[0][\"content\"]\n",
|
| 242 |
+
"\n",
|
| 243 |
+
" # Periodic logging\n",
|
| 244 |
" if PRINTER % 10 == 0:\n",
|
| 245 |
" print(f\"\\n--- [{tn}] Sample {PRINTER} ---\\n{resp[:300]}\")\n",
|
| 246 |
" PRINTER += 1\n",
|
| 247 |
+
"\n",
|
| 248 |
+
" # Parse the FOLD JSON from the LLM response\n",
|
| 249 |
" fold = extract_fold_json(resp)\n",
|
| 250 |
+
" if fold is None:\n",
|
| 251 |
+
" scores.append(-2.0)\n",
|
| 252 |
+
" continue\n",
|
| 253 |
+
"\n",
|
| 254 |
" try:\n",
|
| 255 |
+
" # Reset environment for this task\n",
|
| 256 |
+
" env.reset(task_name=tn)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" # Submit the fold to OpenEnv — environment simulates and scores it\n",
|
| 259 |
+
" result = env.step(fold)\n",
|
| 260 |
+
"\n",
|
| 261 |
+
" # Get reward from the environment\n",
|
| 262 |
+
" reward = result.get(\"reward\", None)\n",
|
| 263 |
+
" if reward is not None:\n",
|
| 264 |
+
" scores.append(float(reward))\n",
|
| 265 |
+
" else:\n",
|
| 266 |
+
" # Fallback: extract from observation\n",
|
| 267 |
+
" obs = result.get(\"observation\", {})\n",
|
| 268 |
+
" if obs.get(\"error\"):\n",
|
| 269 |
+
" scores.append(-2.0)\n",
|
| 270 |
+
" else:\n",
|
| 271 |
+
" sim = obs.get(\"shape_similarity\", 0.0)\n",
|
| 272 |
+
" scores.append(float(sim) * 20.0)\n",
|
| 273 |
+
"\n",
|
| 274 |
+
" except requests.exceptions.RequestException as e:\n",
|
| 275 |
+
" print(f\"OpenEnv API error: {e}\")\n",
|
| 276 |
+
" scores.append(-1.0)\n",
|
| 277 |
+
" except Exception as e:\n",
|
| 278 |
+
" print(f\"Reward error: {e}\")\n",
|
| 279 |
+
" scores.append(-1.0)\n",
|
| 280 |
+
"\n",
|
| 281 |
" return scores\n",
|
| 282 |
"\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"print(\"Reward functions ready (valid_fold=local, openenv_reward=API).\")"
|
| 285 |
]
|
| 286 |
},
|
| 287 |
{
|
| 288 |
"cell_type": "markdown",
|
| 289 |
"metadata": {},
|
| 290 |
"source": [
|
| 291 |
+
"## 5. Prompt Template & Multi-Task Dataset\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"Each row is a different task — the model must generalize across shapes."
|
| 294 |
]
|
| 295 |
},
|
| 296 |
{
|
|
|
|
| 299 |
"metadata": {},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
+
"from datasets import Dataset\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"PROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\n",
|
| 305 |
+
"that, when folded, produces the target shape described below.\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"Target: {description}\n",
|
| 308 |
+
"Paper size: {width} x {height}\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"Output a JSON object with these exact fields:\n",
|
| 311 |
+
"- vertices_coords: [[x, y], ...] — 2D positions on the flat paper\n",
|
| 312 |
+
"- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges\n",
|
| 313 |
+
"- edges_assignment: [\"B\"|\"M\"|\"V\", ...] — B=boundary, M=mountain fold, V=valley fold\n",
|
| 314 |
+
"- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative, V: positive, B: 0)\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"Rules:\n",
|
| 317 |
+
"- Boundary edges (B) must outline the paper rectangle\n",
|
| 318 |
+
"- At least one fold crease (M or V) must exist\n",
|
| 319 |
+
"- Mountain fold angles are negative (-180 to 0)\n",
|
| 320 |
+
"- Valley fold angles are positive (0 to 180)\n",
|
| 321 |
+
"- All vertex indices in edges must be valid (0 to N-1)\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"Output ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"def build_prompt(task):\n",
|
| 327 |
+
" return PROMPT_TEMPLATE.format(\n",
|
| 328 |
+
" description=task[\"description\"],\n",
|
| 329 |
+
" width=task[\"paper\"][\"width\"],\n",
|
| 330 |
+
" height=task[\"paper\"][\"height\"],\n",
|
| 331 |
+
" )\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"# Fetch tasks from the OpenEnv environment\n",
|
| 335 |
+
"env_tasks = env.get_tasks()\n",
|
| 336 |
+
"task_names = list(env_tasks.keys())\n",
|
| 337 |
+
"print(f\"Tasks from OpenEnv: {task_names}\")\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"# Build multi-task dataset\n",
|
| 340 |
+
"rows = []\n",
|
| 341 |
+
"for i in range(1000):\n",
|
| 342 |
+
" tn = task_names[i % len(task_names)]\n",
|
| 343 |
+
" t = env_tasks[tn]\n",
|
| 344 |
+
" rows.append({\n",
|
| 345 |
+
" \"prompt\": [{\"role\": \"user\", \"content\": build_prompt(t)}],\n",
|
| 346 |
+
" \"task_name\": tn,\n",
|
| 347 |
+
" \"answer\": 0,\n",
|
| 348 |
+
" })\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"grpo_dataset = Dataset.from_list(rows)\n",
|
| 351 |
+
"print(f\"Dataset: {len(grpo_dataset)} rows across {len(task_names)} tasks\")"
|
| 352 |
]
|
| 353 |
},
|
| 354 |
{
|
| 355 |
"cell_type": "markdown",
|
| 356 |
"metadata": {},
|
| 357 |
"source": [
|
| 358 |
+
"## 6. Configuration"
|
| 359 |
]
|
| 360 |
},
|
| 361 |
{
|
|
|
|
| 389 |
"cell_type": "markdown",
|
| 390 |
"metadata": {},
|
| 391 |
"source": [
|
| 392 |
+
"## 7. Load Model + LoRA"
|
| 393 |
]
|
| 394 |
},
|
| 395 |
{
|
|
|
|
| 418 |
"cell_type": "markdown",
|
| 419 |
"metadata": {},
|
| 420 |
"source": [
|
| 421 |
+
"## 8. GRPO Training\n",
|
| 422 |
+
"\n",
|
| 423 |
+
"Each training step:\n",
|
| 424 |
+
"1. Pick a task prompt (rotates through all 7 origami tasks)\n",
|
| 425 |
+
"2. Generate N completions (FOLD JSON attempts)\n",
|
| 426 |
+
"3. Score each: `valid_fold_reward` (local) + `openenv_reward` (calls HF Space API)\n",
|
| 427 |
+
"4. GRPO computes group-relative advantages and updates the policy\n",
|
| 428 |
"\n",
|
| 429 |
+
"The **OpenEnv environment** on HF Spaces handles the fold simulation and shape comparison."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
]
|
| 431 |
},
|
| 432 |
{
|
|
|
|
| 459 |
"trainer = GRPOTrainer(\n",
|
| 460 |
" model=model,\n",
|
| 461 |
" processing_class=tokenizer,\n",
|
| 462 |
+
" reward_funcs=[valid_fold_reward, openenv_reward],\n",
|
| 463 |
" args=training_args,\n",
|
| 464 |
" train_dataset=grpo_dataset,\n",
|
| 465 |
")\n",
|
| 466 |
"\n",
|
| 467 |
+
"print(\"Starting GRPO training (rewards from OpenEnv API)...\")\n",
|
| 468 |
"trainer.train()\n",
|
| 469 |
"print(\"Done!\")"
|
| 470 |
]
|
|
|
|
| 473 |
"cell_type": "markdown",
|
| 474 |
"metadata": {},
|
| 475 |
"source": [
|
| 476 |
+
"## 9. Test Across Tasks"
|
| 477 |
]
|
| 478 |
},
|
| 479 |
{
|
|
|
|
| 484 |
"source": [
|
| 485 |
"FastLanguageModel.for_inference(model)\n",
|
| 486 |
"\n",
|
| 487 |
+
"for tn in list(env_tasks.keys())[:4]:\n",
|
| 488 |
+
" task = env_tasks[tn]\n",
|
| 489 |
" prompt = build_prompt(task)\n",
|
| 490 |
" inputs = tokenizer.apply_chat_template(\n",
|
| 491 |
" [{\"role\": \"user\", \"content\": prompt}],\n",
|
|
|
|
| 496 |
"\n",
|
| 497 |
" fold = extract_fold_json(response)\n",
|
| 498 |
" if fold:\n",
|
| 499 |
+
" try:\n",
|
| 500 |
+
" env.reset(task_name=tn)\n",
|
| 501 |
+
" result = env.step(fold)\n",
|
| 502 |
+
" reward = result.get(\"reward\", 0)\n",
|
| 503 |
+
" sim = result[\"observation\"].get(\"shape_similarity\", 0)\n",
|
| 504 |
+
" print(f\"{tn:15s} | reward={reward:6.2f} | similarity={sim:.3f}\")\n",
|
| 505 |
+
" except Exception as e:\n",
|
| 506 |
+
" print(f\"{tn:15s} | env error: {e}\")\n",
|
| 507 |
" else:\n",
|
| 508 |
" print(f\"{tn:15s} | no JSON extracted\")"
|
| 509 |
]
|
|
|
|
| 512 |
"cell_type": "markdown",
|
| 513 |
"metadata": {},
|
| 514 |
"source": [
|
| 515 |
+
"## 10. Save Model"
|
| 516 |
]
|
| 517 |
},
|
| 518 |
{
|