sissississi Claude Opus 4.6 committed on
Commit
bc52096
·
1 Parent(s): 654687c

Add RL training environment with OpenEnv backend

Browse files

- origami_server/: Full OpenEnv environment with physics simulator
- engine/: FOLD parser, analytical fold simulator (BFS + rotation
transforms), chamfer-distance shape matching
- environment.py: OrigamiEnvironment (reset/step/state)
- tasks.py: 7 tasks (triangle, half_fold, quarter_fold, letter_fold,
waterbomb, accordion, miura_ori)
- app.py: FastAPI server with /tasks endpoints
- training/: GRPO training scripts for Colab
- reward.py: valid_fold + shape_match reward functions
- train_grpo.py: Unsloth + TRL GRPOTrainer pipeline
- Dual-service Docker: Next.js UI on :7860, OpenEnv API on :8000

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.dockerignore CHANGED
@@ -3,3 +3,4 @@ node_modules
3
  .git
4
  __pycache__
5
  *.md
 
 
3
  .git
4
  __pycache__
5
  *.md
6
+ outputs/
Dockerfile CHANGED
@@ -1,39 +1,54 @@
1
  FROM node:20-alpine AS base
2
 
3
- # --- Dependencies ---
4
  FROM base AS deps
5
  WORKDIR /app
6
  COPY package.json package-lock.json ./
7
  RUN npm ci
8
 
9
- # --- Build ---
10
  FROM base AS builder
11
  WORKDIR /app
12
  COPY --from=deps /app/node_modules ./node_modules
13
- COPY . .
14
-
 
 
 
 
15
  ENV NEXT_TELEMETRY_DISABLED=1
16
  RUN npm run build
17
 
18
- # --- Runner ---
19
- FROM base AS runner
20
  WORKDIR /app
21
 
 
 
 
22
  ENV NODE_ENV=production
23
  ENV NEXT_TELEMETRY_DISABLED=1
24
  ENV PORT=7860
25
  ENV HOSTNAME=0.0.0.0
26
 
27
- RUN addgroup --system --gid 1001 nodejs
28
- RUN adduser --system --uid 1001 nextjs
29
-
30
  COPY --from=builder /app/public ./public
31
- COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone ./
32
- COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static
 
 
 
 
 
 
 
 
33
 
34
- USER nextjs
 
 
35
 
36
  EXPOSE 7860
 
37
 
38
- # Start the Next.js standalone server
39
- CMD ["node", "server.js"]
 
1
  FROM node:20-alpine AS base
2
 
3
+ # --- Node dependencies ---
4
  FROM base AS deps
5
  WORKDIR /app
6
  COPY package.json package-lock.json ./
7
  RUN npm ci
8
 
9
+ # --- Next.js build ---
10
  FROM base AS builder
11
  WORKDIR /app
12
  COPY --from=deps /app/node_modules ./node_modules
13
+ COPY app/ ./app/
14
+ COPY components/ ./components/
15
+ COPY hooks/ ./hooks/
16
+ COPY lib/ ./lib/
17
+ COPY public/ ./public/
18
+ COPY package.json package-lock.json tsconfig.json next.config.ts postcss.config.mjs eslint.config.mjs .eslintrc.json next-env.d.ts metadata.json ./
19
  ENV NEXT_TELEMETRY_DISABLED=1
20
  RUN npm run build
21
 
22
+ # --- Final runner: Node + Python ---
23
+ FROM node:20-alpine AS runner
24
  WORKDIR /app
25
 
26
+ # Install Python + pip for the OpenEnv API
27
+ RUN apk add --no-cache python3 py3-pip py3-numpy py3-scipy
28
+
29
  ENV NODE_ENV=production
30
  ENV NEXT_TELEMETRY_DISABLED=1
31
  ENV PORT=7860
32
  ENV HOSTNAME=0.0.0.0
33
 
34
+ # Copy Next.js standalone build
 
 
35
  COPY --from=builder /app/public ./public
36
+ COPY --from=builder /app/.next/standalone ./
37
+ COPY --from=builder /app/.next/static ./.next/static
38
+
39
+ # Copy Python RL environment
40
+ COPY origami_server/ ./origami_server/
41
+ COPY training/ ./training/
42
+ COPY requirements.txt ./
43
+
44
+ # Install Python deps
45
+ RUN pip3 install --no-cache-dir --break-system-packages -r requirements.txt
46
 
47
+ # Start script: run both services
48
+ COPY start.sh ./
49
+ RUN chmod +x start.sh
50
 
51
  EXPOSE 7860
52
+ EXPOSE 8000
53
 
54
+ CMD ["./start.sh"]
 
README.md CHANGED
@@ -7,5 +7,5 @@ sdk: docker
7
  app_port: 7860
8
  pinned: false
9
  license: mit
10
- short_description: Interactive 3D origami simulator with AI pattern generation
11
  ---
 
7
  app_port: 7860
8
  pinned: false
9
  license: mit
10
+ short_description: Interactive 3D origami simulator + RL training environment
11
  ---
origami_server/__init__.py ADDED
File without changes
origami_server/app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI entry point — OpenEnv create_app() + custom endpoints."""
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from fastapi.responses import HTMLResponse
7
+ from fastapi.staticfiles import StaticFiles
8
+
9
+ from openenv.core.env_server.http_server import create_app
10
+
11
+ from .environment import OrigamiEnvironment
12
+ from .models import OrigamiAction, OrigamiObservation
13
+
14
# Build the ASGI application through OpenEnv's create_app factory, handing it
# the environment class together with its action and observation models.
# NOTE(review): presumably create_app registers the standard OpenEnv
# reset/step/state routes — confirm against the openenv-core documentation.
app = create_app(
    OrigamiEnvironment,
    OrigamiAction,
    OrigamiObservation,
    env_name="origami_env",
)
20
+
21
+ from .tasks import TASKS
22
+
23
+
24
@app.get("/tasks")
def get_tasks():
    """Return a summary of every registered task, keyed by task name.

    Each summary carries the task's name, description, difficulty, paper
    dimensions and reference FOLD pattern.
    """
    exposed_fields = ("name", "description", "difficulty", "paper", "target_fold")
    catalog = {}
    for task_key, spec in TASKS.items():
        catalog[task_key] = {field: spec[field] for field in exposed_fields}
    return catalog
36
+
37
+
38
@app.get("/tasks/{task_name}")
def get_task_detail(task_name: str):
    """Return the summary of a single task.

    Raises:
        HTTPException: 404 when ``task_name`` is not a registered task.
    """
    spec = TASKS.get(task_name)
    if spec is None:
        # Imported lazily: only needed on the error path.
        from fastapi import HTTPException

        raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")

    exposed_fields = ("name", "description", "difficulty", "paper", "target_fold")
    return {field: spec[field] for field in exposed_fields}
52
+
53
+
54
def main():
    """Serve ``app`` with uvicorn on all interfaces.

    The listening port is read from the ``PORT`` environment variable and
    falls back to 8000 when the variable is unset.
    """
    import uvicorn

    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
59
+
60
+
61
# Run the server only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
origami_server/engine/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .simulate import simulate
2
+ from .fold_parser import parse_fold, validate_fold
3
+ from .shape_match import compute_shape_match
origami_server/engine/fold_parser.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FOLD JSON parsing and validation.
2
+
3
+ Validates LLM-generated FOLD crease patterns before simulation.
4
+ FOLD spec: https://github.com/edemaine/fold
5
+ """
6
+
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+
11
+
12
def validate_fold(fold_data: dict[str, Any]) -> tuple[bool, str]:
    """Validate a FOLD JSON object. Returns (is_valid, error_message).

    The input is untrusted (LLM-generated), so this function never raises on
    malformed data: every structural or type problem is reported as
    ``(False, message)``. On success it returns ``(True, "")``.

    Checks performed, in order:
      * ``fold_data`` is a dict containing ``vertices_coords``,
        ``edges_vertices`` and ``edges_assignment``, each a list/tuple;
      * at least 3 vertices and 3 edges;
      * ``edges_assignment`` (and ``edges_foldAngle`` when present) have one
        entry per edge;
      * every vertex is ``[x, y]`` or ``[x, y, z]`` with finite real numbers;
      * every edge is a pair of distinct, in-range integer vertex indices;
      * every assignment is one of M, V, B, F, U, C;
      * at least one crease (M or V) and at least one boundary edge (B).

    Args:
        fold_data: Parsed FOLD JSON.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, reason)``.
    """
    import math

    def _is_real(value: Any) -> bool:
        # bool is a subclass of int but is never a meaningful coordinate.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _is_index(value: Any) -> bool:
        return isinstance(value, int) and not isinstance(value, bool)

    if not isinstance(fold_data, dict):
        return False, (
            f"FOLD data must be a JSON object, got {type(fold_data).__name__}"
        )

    # Required fields
    for key in ("vertices_coords", "edges_vertices", "edges_assignment"):
        if key not in fold_data:
            return False, f"Missing required field: {key}"

    verts = fold_data["vertices_coords"]
    edges = fold_data["edges_vertices"]
    assignments = fold_data["edges_assignment"]

    # Guard the len()/iteration below against scalars, None, strings, etc.
    for key, value in (
        ("vertices_coords", verts),
        ("edges_vertices", edges),
        ("edges_assignment", assignments),
    ):
        if not isinstance(value, (list, tuple)):
            return False, f"{key} must be a list, got {type(value).__name__}"

    # Must have at least 3 vertices (a triangle)
    if len(verts) < 3:
        return False, f"Need at least 3 vertices, got {len(verts)}"

    # Must have at least 3 edges
    if len(edges) < 3:
        return False, f"Need at least 3 edges, got {len(edges)}"

    # Edges and assignments must match length
    if len(edges) != len(assignments):
        return False, (
            f"edges_vertices ({len(edges)}) and "
            f"edges_assignment ({len(assignments)}) must match length"
        )

    # Fold angles must match if present
    if "edges_foldAngle" in fold_data:
        angles = fold_data["edges_foldAngle"]
        if not isinstance(angles, (list, tuple)):
            return False, (
                f"edges_foldAngle must be a list, got {type(angles).__name__}"
            )
        if len(angles) != len(edges):
            return False, (
                f"edges_foldAngle ({len(angles)}) must match "
                f"edges_vertices ({len(edges)})"
            )

    # Validate vertex coordinates (2D or 3D)
    num_verts = len(verts)
    for i, v in enumerate(verts):
        if not isinstance(v, (list, tuple)) or len(v) < 2:
            return False, f"Vertex {i} must be [x, y] or [x, y, z], got {v}"
        # Only the first three components are ever read by the parser.
        for coord in v[:3]:
            if not _is_real(coord) or (
                isinstance(coord, float) and not math.isfinite(coord)
            ):
                return False, (
                    f"Vertex {i} has a non-numeric or non-finite "
                    f"coordinate: {v}"
                )

    # Validate edge indices
    for i, e in enumerate(edges):
        if not isinstance(e, (list, tuple)) or len(e) != 2:
            return False, f"Edge {i} must be [v1, v2], got {e}"
        v1, v2 = e
        # Type-check before comparing: "0" < 0 or None < 0 would raise.
        if not (_is_index(v1) and _is_index(v2)):
            return False, f"Edge {i} must contain integer vertex indices, got {e}"
        if v1 < 0 or v1 >= num_verts or v2 < 0 or v2 >= num_verts:
            return False, f"Edge {i} references invalid vertex: {e}"
        if v1 == v2:
            return False, f"Edge {i} is degenerate (same vertex): {e}"

    # Validate assignments
    valid_assignments = {"M", "V", "B", "F", "U", "C"}
    for i, a in enumerate(assignments):
        # isinstance first: an unhashable value (list/dict) would make the
        # set membership test raise TypeError.
        if not isinstance(a, str) or a not in valid_assignments:
            return False, f"Edge {i} has invalid assignment '{a}'"

    # Must have at least one fold crease (M or V)
    if not any(a in ("M", "V") for a in assignments):
        return False, "No fold creases (M or V) found"

    # Must have boundary edges
    if "B" not in assignments:
        return False, "No boundary edges (B) found"

    return True, ""
81
+
82
+
83
def parse_fold(fold_data: dict[str, Any]) -> dict[str, np.ndarray]:
    """Convert a validated FOLD object into numpy arrays for the simulator.

    Returns a dict with:
        vertices:    (N, 3) float64 vertex positions; z is 0 for 2D input
        edges:       (E, 2) int32 vertex indices per edge
        assignments: list[str] with one assignment letter per edge
        fold_angles: (E,) float64 target fold angle per edge, in radians
        faces:       (F, 3) int32 triangulated face vertex indices
    """
    coords = fold_data["vertices_coords"]

    # Lift every vertex into 3D; a missing z component stays at 0.
    vertices = np.zeros((len(coords), 3), dtype=np.float64)
    for idx, point in enumerate(coords):
        vertices[idx, 0] = point[0]
        vertices[idx, 1] = point[1]
        if len(point) > 2:
            vertices[idx, 2] = point[2]

    edges = np.array(fold_data["edges_vertices"], dtype=np.int32)
    assignments = list(fold_data["edges_assignment"])

    if "edges_foldAngle" in fold_data:
        angles_deg = np.array(fold_data["edges_foldAngle"], dtype=np.float64)
    else:
        # No explicit angles: valley creases fold to +180, mountain creases
        # to -180, and every other assignment stays flat.
        angles_deg = np.zeros(len(edges), dtype=np.float64)
        for idx, kind in enumerate(assignments):
            if kind in ("V", "M"):
                angles_deg[idx] = 180.0 if kind == "V" else -180.0

    # Use the supplied faces when available, otherwise derive them.
    faces = (
        _triangulate_faces(fold_data["faces_vertices"])
        if "faces_vertices" in fold_data
        else _compute_faces(vertices, edges)
    )

    return {
        "vertices": vertices,
        "edges": edges,
        "assignments": assignments,
        # The simulator works in radians.
        "fold_angles": np.radians(angles_deg),
        "faces": faces,
    }
134
+
135
+
136
+ def _triangulate_faces(raw_faces: list[list[int]]) -> np.ndarray:
137
+ """Fan-triangulate polygon faces into triangles."""
138
+ triangles = []
139
+ for face in raw_faces:
140
+ if len(face) < 3:
141
+ continue
142
+ for i in range(1, len(face) - 1):
143
+ triangles.append([face[0], face[i], face[i + 1]])
144
+ if not triangles:
145
+ return np.zeros((0, 3), dtype=np.int32)
146
+ return np.array(triangles, dtype=np.int32)
147
+
148
+
149
+ def _compute_faces(vertices: np.ndarray, edges: np.ndarray) -> np.ndarray:
150
+ """Compute triangulated faces from vertices and edges using adjacency."""
151
+ from collections import defaultdict
152
+
153
+ adj = defaultdict(set)
154
+ for v1, v2 in edges:
155
+ adj[v1].add(v2)
156
+ adj[v2].add(v1)
157
+
158
+ triangles = set()
159
+ for v1, v2 in edges:
160
+ common = adj[v1] & adj[v2]
161
+ for v3 in common:
162
+ tri = tuple(sorted([v1, v2, v3]))
163
+ triangles.add(tri)
164
+
165
+ if not triangles:
166
+ # Fallback: create faces using Delaunay on 2D projection
167
+ from scipy.spatial import Delaunay
168
+
169
+ pts_2d = vertices[:, :2]
170
+ try:
171
+ tri = Delaunay(pts_2d)
172
+ return tri.simplices.astype(np.int32)
173
+ except Exception:
174
+ return np.zeros((0, 3), dtype=np.int32)
175
+
176
+ return np.array(list(triangles), dtype=np.int32)
origami_server/engine/shape_match.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shape matching for reward computation.
2
+
3
+ Computes similarity between the LLM's folded shape and the target shape.
4
+ Like AlphaFold's RMSD but for origami vertex positions.
5
+ """
6
+
7
+ import numpy as np
8
+ from scipy.spatial.distance import cdist
9
+
10
+
11
def compute_shape_match(
    predicted: np.ndarray,
    target: np.ndarray,
) -> float:
    """Score how closely the predicted vertex cloud matches the target.

    Both clouds are translated so their centroids sit at the origin. The
    predicted cloud is then tried under each candidate alignment transform,
    and the best chamfer-based similarity is kept.

    Args:
        predicted: (N, 3) predicted vertex positions.
        target: (M, 3) target vertex positions.

    Returns:
        Similarity in [0, 1], where 1.0 is a perfect match. Returns 0.0 when
        either cloud is empty.
    """
    if len(predicted) == 0 or len(target) == 0:
        return 0.0

    # Remove translation before comparing.
    pred_origin = predicted - predicted.mean(axis=0)
    target_origin = target - target.mean(axis=0)

    best = 0.0
    for transform in _get_alignment_rotations():
        candidate = _chamfer_similarity(pred_origin @ transform.T, target_origin)
        if candidate > best:
            best = candidate
    return best
42
+
43
+
44
+ def _chamfer_similarity(a: np.ndarray, b: np.ndarray) -> float:
45
+ """Chamfer distance converted to similarity score."""
46
+ d = cdist(a, b)
47
+
48
+ # Forward: for each point in a, min distance to b
49
+ forward = d.min(axis=1).mean()
50
+ # Backward: for each point in b, min distance to a
51
+ backward = d.min(axis=0).mean()
52
+ chamfer = (forward + backward) / 2.0
53
+
54
+ # Normalize by bounding box diagonal of target
55
+ all_pts = np.vstack([a, b])
56
+ bbox_diag = np.linalg.norm(all_pts.max(axis=0) - all_pts.min(axis=0))
57
+ if bbox_diag < 1e-12:
58
+ return 1.0 if chamfer < 1e-12 else 0.0
59
+
60
+ similarity = max(0.0, 1.0 - chamfer / bbox_diag)
61
+ return similarity
62
+
63
+
64
+ def _get_alignment_rotations() -> list[np.ndarray]:
65
+ """Generate rotation matrices for alignment search.
66
+
67
+ Identity + 90 deg rotations around each axis + mirrors (15 total).
68
+ """
69
+ I = np.eye(3)
70
+ rotations = [I]
71
+
72
+ # 90 deg rotations around Z axis
73
+ for k in range(1, 4):
74
+ angle = k * np.pi / 2
75
+ c, s = np.cos(angle), np.sin(angle)
76
+ rotations.append(np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]))
77
+
78
+ # 90 deg rotations around X axis
79
+ for k in range(1, 4):
80
+ angle = k * np.pi / 2
81
+ c, s = np.cos(angle), np.sin(angle)
82
+ rotations.append(np.array([[1, 0, 0], [0, c, -s], [0, s, c]]))
83
+
84
+ # 90 deg rotations around Y axis
85
+ for k in range(1, 4):
86
+ angle = k * np.pi / 2
87
+ c, s = np.cos(angle), np.sin(angle)
88
+ rotations.append(np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]))
89
+
90
+ # Mirrors
91
+ rotations.append(np.diag([-1.0, 1.0, 1.0]))
92
+ rotations.append(np.diag([1.0, -1.0, 1.0]))
93
+ rotations.append(np.diag([1.0, 1.0, -1.0]))
94
+
95
+ return rotations
origami_server/engine/simulate.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Origami fold simulator — analytical rotation with cumulative transforms.
2
+
3
+ BFS from face 0 through the face adjacency graph. Each face accumulates
4
+ a rotation transform (R, t) such that: folded_pos = R @ flat_pos + t.
5
+ When crossing a fold edge, the fold rotation is composed with the parent
6
+ face's transform. Non-fold edges inherit the parent's transform directly.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+ from scipy.spatial.transform import Rotation
13
+
14
+ from .fold_parser import parse_fold
15
+
16
+
17
@dataclass
class SimResult:
    """Outcome of one call to the fold simulator."""

    # (N, 3) vertex positions after folding.
    positions: np.ndarray
    # Whether the simulation is considered to have finished successfully.
    converged: bool
    # Number of simulation passes performed.
    steps_taken: int
    # Largest relative edge-length change versus the flat pattern.
    max_strain: float
    # Energy of the final state.
    total_energy: float
26
+
27
+
28
def simulate(
    fold_data: dict,
    crease_percent: float = 1.0,
    max_steps: int = 500,
    params: dict | None = None,
) -> SimResult:
    """Simulate a FOLD crease pattern and return final 3D positions.

    The fold is computed analytically in a single breadth-first pass over the
    face adjacency graph, starting at face 0 (which stays where it is). Each
    face carries a rigid transform ``(R, t)`` such that
    ``folded = R @ flat + t``. Crossing an M/V crease composes a rotation
    about that crease with the parent face's transform; crossing any other
    edge inherits the parent's transform unchanged. Faces not reachable from
    face 0 keep their flat positions.

    Args:
        fold_data: FOLD object, as accepted by ``parse_fold``.
        crease_percent: Scale applied to every crease's target fold angle
            (1.0 = fully folded).
        max_steps: Unused by the analytical solver; kept so the signature
            stays stable for callers.
        params: Unused by the analytical solver; kept for the same reason.

    Returns:
        A ``SimResult`` with the folded positions and the maximum edge
        strain. ``converged`` is always True and ``total_energy`` always 0.0;
        ``steps_taken`` is 0 when the pattern has no faces and 1 otherwise.
    """
    # deque gives O(1) pops from the front; list.pop(0) is O(n) per dequeue.
    from collections import deque

    parsed = parse_fold(fold_data)
    flat_pos = parsed["vertices"].copy()
    edges = parsed["edges"]
    assignments = parsed["assignments"]
    fold_angles = parsed["fold_angles"]
    faces = parsed["faces"]
    positions = flat_pos.copy()

    if len(faces) == 0:
        # Nothing to fold: report the flat sheet.
        return SimResult(
            positions=positions, converged=True,
            steps_taken=0, max_strain=0.0, total_energy=0.0,
        )

    face_adj = _build_face_adjacency(faces)

    # Fold angle (radians, already scaled) for each M/V crease, keyed by the
    # sorted vertex pair so lookups do not depend on edge direction.
    crease_map: dict[tuple[int, int], float] = {}
    for i, (v1, v2) in enumerate(edges):
        key = (min(int(v1), int(v2)), max(int(v1), int(v2)))
        if assignments[i] in ("M", "V"):
            crease_map[key] = fold_angles[i] * crease_percent

    n_faces = len(faces)
    face_R = [None] * n_faces
    face_t = [None] * n_faces

    # Face 0 is the fixed reference frame.
    face_R[0] = np.eye(3)
    face_t[0] = np.zeros(3)

    visited = [False] * n_faces
    visited[0] = True

    # Vertices whose folded position is final; the first face to reach a
    # vertex decides where it goes.
    placed: set[int] = {int(vi) for vi in faces[0]}

    queue = deque([0])
    while queue:
        fi = queue.popleft()
        face = faces[fi]

        for j in range(len(face)):
            v1, v2 = int(face[j]), int(face[(j + 1) % len(face)])
            edge_key = (min(v1, v2), max(v1, v2))

            for fj in face_adj.get(edge_key, []):
                if visited[fj]:
                    continue
                visited[fj] = True
                queue.append(fj)

                angle = crease_map.get(edge_key, 0.0)

                if abs(angle) > 1e-10:
                    # Rotate about the crease as it currently lies in space
                    # (its endpoints were already placed via the parent face).
                    p1 = positions[v1].copy()
                    axis = positions[v2] - p1
                    axis_len = np.linalg.norm(axis)
                    if axis_len > 1e-12:
                        axis_unit = axis / axis_len
                        fold_rot = Rotation.from_rotvec(
                            angle * axis_unit,
                        ).as_matrix()
                    else:
                        # Zero-length crease: no well-defined axis.
                        fold_rot = np.eye(3)

                    # Apply the parent transform, then rotate about the line
                    # through p1: x -> fold_rot @ (R_i x + t_i - p1) + p1.
                    face_R[fj] = fold_rot @ face_R[fi]
                    face_t[fj] = fold_rot @ (face_t[fi] - p1) + p1
                else:
                    face_R[fj] = face_R[fi].copy()
                    face_t[fj] = face_t[fi].copy()

                for vi in faces[fj]:
                    vi_int = int(vi)
                    if vi_int not in placed:
                        positions[vi_int] = face_R[fj] @ flat_pos[vi_int] + face_t[fj]
                        placed.add(vi_int)

    max_strain = _compute_strain(positions, parsed)

    return SimResult(
        positions=positions,
        converged=True,
        steps_taken=1,
        max_strain=max_strain,
        total_energy=0.0,
    )
121
+
122
+
123
+ def _build_face_adjacency(
124
+ faces: np.ndarray,
125
+ ) -> dict[tuple[int, int], list[int]]:
126
+ """Map each edge (sorted vertex pair) to list of face indices."""
127
+ adj: dict[tuple[int, int], list[int]] = {}
128
+ for fi, face in enumerate(faces):
129
+ n = len(face)
130
+ for j in range(n):
131
+ v1, v2 = int(face[j]), int(face[(j + 1) % n])
132
+ key = (min(v1, v2), max(v1, v2))
133
+ if key not in adj:
134
+ adj[key] = []
135
+ adj[key].append(fi)
136
+ return adj
137
+
138
+
139
+ def _compute_strain(positions: np.ndarray, parsed: dict) -> float:
140
+ """Compute max axial strain across all edges."""
141
+ edges = parsed["edges"]
142
+ vertices_flat = parsed["vertices"]
143
+ max_strain = 0.0
144
+ for v1, v2 in edges:
145
+ rest = np.linalg.norm(vertices_flat[v2] - vertices_flat[v1])
146
+ curr = np.linalg.norm(positions[v2] - positions[v1])
147
+ if rest > 1e-12:
148
+ strain = abs(curr - rest) / rest
149
+ max_strain = max(max_strain, strain)
150
+ return max_strain
origami_server/environment.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Origami RL Environment — OpenEnv Environment subclass.
2
+
3
+ Single-shot episodes: LLM submits a FOLD crease pattern, physics simulates it,
4
+ reward = shape similarity to target. Like AlphaFold for origami.
5
+ """
6
+
7
+ import uuid
8
+ from typing import Any, Optional
9
+
10
+ import numpy as np
11
+ from openenv.core import Environment
12
+
13
+ from .engine.fold_parser import validate_fold
14
+ from .engine.shape_match import compute_shape_match
15
+ from .engine.simulate import SimResult, simulate
16
+ from .models import OrigamiAction, OrigamiObservation, OrigamiState
17
+ from .tasks import get_task
18
+
19
+
20
class OrigamiEnvironment(
    Environment[OrigamiAction, OrigamiObservation, OrigamiState]
):
    """Origami folding environment.

    Episode flow:
    1. reset(task_name="triangle") -> returns task description + target info
    2. step(OrigamiAction(fold_data={...})) -> simulates, scores, returns done=True

    Single action per episode. The action IS the complete crease pattern.

    Rewards: ``shape_similarity * 20`` for a pattern that validates and
    simulates; ``-2.0`` for a pattern that fails validation or simulation.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._state = OrigamiState()
        # Task definition selected by the latest reset(); empty before that.
        self._task: dict = {}
        # Folded vertex positions of the task's reference pattern.
        self._target_positions: np.ndarray = np.zeros((0, 3))

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Start a new episode with a target shape task.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the episode; a UUID4 is generated
                when omitted.
            **kwargs: ``task_name`` selects the task (default "triangle").

        Returns:
            The initial observation, carrying the task summary and the target
            vertex positions.

        Raises:
            ValueError: If ``task_name`` is not a known task (from get_task).
        """
        import logging

        self._state = OrigamiState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
        )

        task_name = kwargs.get("task_name", "triangle")
        self._task = get_task(task_name)
        self._state.task_name = self._task["name"]

        target_fold = self._task["target_fold"]
        try:
            target_result = simulate(target_fold, crease_percent=1.0)
            self._target_positions = target_result.positions
        except Exception:
            # Best effort: keep the episode usable, but leave a trace. With an
            # empty target every submission scores 0, which is otherwise very
            # hard to diagnose.
            logging.getLogger(__name__).exception(
                "Failed to simulate target pattern for task %r",
                self._state.task_name,
            )
            self._target_positions = np.zeros((0, 3))

        return OrigamiObservation(
            done=False,
            reward=None,
            task=self._task_info(),
            fold_data={},
            final_positions=[],
            target_positions=self._target_positions.tolist(),
            shape_similarity=0.0,
            max_strain=0.0,
            is_stable=True,
            error=None,
        )

    def step(
        self,
        action: OrigamiAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OrigamiObservation:
        """Evaluate the LLM's crease pattern.

        1. Validate FOLD data
        2. Run physics simulation (creasePercent=1.0)
        3. Compare final shape to target
        4. Return observation with reward = similarity * 20

        Args:
            action: The submitted crease pattern.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            A terminal observation (``done=True``). Validation or simulation
            failures yield reward -2.0 and a populated ``error``.
        """
        self._state.step_count += 1
        fold_data = action.fold_data

        is_valid, error_msg = validate_fold(fold_data)
        if not is_valid:
            return self._failure_observation(
                fold_data, f"Invalid FOLD data: {error_msg}"
            )

        try:
            result: SimResult = simulate(fold_data, crease_percent=1.0)
        except Exception as e:
            # Untrusted input boundary: any simulator failure is a penalty,
            # not a server error.
            return self._failure_observation(
                fold_data, f"Simulation error: {str(e)}"
            )

        if not np.all(np.isfinite(result.positions)):
            # NaN/inf geometry cannot be scored meaningfully and is not
            # representable in strict JSON, so treat it as a failed run.
            return self._failure_observation(
                fold_data, "Simulation error: non-finite vertex positions"
            )

        similarity = compute_shape_match(
            result.positions, self._target_positions
        )
        reward = similarity * 20.0

        self._state.shape_similarity = similarity
        self._state.is_stable = result.converged

        return OrigamiObservation(
            done=True,
            reward=reward,
            task=self._task_info(),
            fold_data=fold_data,
            final_positions=result.positions.tolist(),
            target_positions=self._target_positions.tolist(),
            shape_similarity=similarity,
            max_strain=result.max_strain,
            is_stable=result.converged,
            error=None,
        )

    @property
    def state(self) -> OrigamiState:
        """Current episode state."""
        return self._state

    def _task_info(self) -> dict:
        """Return the task summary exposed in observations (no target pattern)."""
        if not self._task:
            return {}
        return {
            "name": self._task.get("name", ""),
            "description": self._task.get("description", ""),
            "difficulty": self._task.get("difficulty", 0),
            "paper": self._task.get("paper", {}),
        }

    def _failure_observation(
        self, fold_data: dict, error: str
    ) -> OrigamiObservation:
        """Mark the episode unstable and build the -2.0 penalty observation."""
        self._state.is_stable = False
        return OrigamiObservation(
            done=True,
            reward=-2.0,
            task=self._task_info(),
            fold_data=fold_data,
            final_positions=[],
            target_positions=self._target_positions.tolist(),
            shape_similarity=0.0,
            max_strain=0.0,
            is_stable=False,
            error=error,
        )
origami_server/models.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv types for the Origami RL environment.
2
+
3
+ OrigamiAction: LLM submits a FOLD crease pattern.
4
+ OrigamiObservation: Result of simulating that pattern against a target.
5
+ OrigamiState: Internal episode state.
6
+ """
7
+
8
+ from typing import Any, Optional
9
+
10
+ from openenv.core import Action, Observation, State
11
+ from pydantic import Field
12
+
13
+
14
class OrigamiAction(Action):
    """LLM submits a FOLD crease pattern as its action.

    The fold_data dict must contain:
    - vertices_coords: [[x, y], ...] — 2D vertex positions on flat paper
    - edges_vertices: [[v1, v2], ...] — edge connectivity
    - edges_assignment: ["B"|"M"|"V", ...] — boundary/mountain/valley
    - edges_foldAngle: [angle, ...] — target fold angles in degrees
    (optional — defaults from assignment: M=-180, V=+180, B=0)
    """

    # Required field (no default): the complete crease pattern for the episode.
    fold_data: dict[str, Any] = Field(
        ..., description="FOLD-format crease pattern JSON"
    )
28
+
29
+
30
class OrigamiObservation(Observation):
    """Result of simulating the LLM's crease pattern."""

    # Summary of the active task (name, description, difficulty, paper).
    task: dict[str, Any] = Field(default_factory=dict)
    # The submitted crease pattern echoed back; empty until a pattern is submitted.
    fold_data: dict[str, Any] = Field(default_factory=dict)
    # Simulated vertex positions of the submission; empty when none are available.
    final_positions: list[list[float]] = Field(default_factory=list)
    # Simulated vertex positions of the task's reference pattern.
    target_positions: list[list[float]] = Field(default_factory=list)
    # Similarity between final and target shapes, 1.0 being a perfect match.
    shape_similarity: float = 0.0
    # Largest relative edge-length change reported by the simulator.
    max_strain: float = 0.0
    # Whether the simulation converged.
    is_stable: bool = True
    # Human-readable failure reason, or None when no error occurred.
    error: Optional[str] = None
41
+
42
+
43
class OrigamiState(State):
    """Internal state for an origami episode."""

    # Name of the task selected for the current episode.
    task_name: str = ""
    # Similarity score of the most recently evaluated submission.
    shape_similarity: float = 0.0
    # Whether the most recent evaluation was considered stable.
    is_stable: bool = True
origami_server/tasks.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Task definitions for origami RL training.
2
+
3
+ Each task defines a target shape as a reference FOLD crease pattern.
4
+ The LLM must discover a crease pattern that folds into the same shape.
5
+ """
6
+
7
+ TASKS: dict[str, dict] = {
8
+ "triangle": {
9
+ "name": "triangle",
10
+ "description": "Fold the paper in half diagonally to make a triangle",
11
+ "difficulty": 1,
12
+ "paper": {"width": 1.0, "height": 1.0},
13
+ "target_fold": {
14
+ "vertices_coords": [[0, 0], [1, 0], [1, 1], [0, 1]],
15
+ "edges_vertices": [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2]],
16
+ "edges_assignment": ["B", "B", "B", "B", "V"],
17
+ "edges_foldAngle": [0, 0, 0, 0, 180],
18
+ "faces_vertices": [[0, 1, 2], [0, 2, 3]],
19
+ },
20
+ },
21
+ "half_fold": {
22
+ "name": "half_fold",
23
+ "description": "Fold the paper in half horizontally",
24
+ "difficulty": 1,
25
+ "paper": {"width": 1.0, "height": 1.0},
26
+ "target_fold": {
27
+ "vertices_coords": [
28
+ [0, 0], [1, 0], [1, 1], [0, 1], [0, 0.5], [1, 0.5],
29
+ ],
30
+ "edges_vertices": [
31
+ [0, 1], [1, 5], [5, 2], [2, 3], [3, 4], [4, 0],
32
+ [4, 5],
33
+ ],
34
+ "edges_assignment": ["B", "B", "B", "B", "B", "B", "V"],
35
+ "edges_foldAngle": [0, 0, 0, 0, 0, 0, 180],
36
+ "faces_vertices": [[0, 1, 5, 4], [4, 5, 2, 3]],
37
+ },
38
+ },
39
+ "quarter_fold": {
40
+ "name": "quarter_fold",
41
+ "description": "Fold the paper into quarters (two perpendicular folds)",
42
+ "difficulty": 2,
43
+ "paper": {"width": 1.0, "height": 1.0},
44
+ "target_fold": {
45
+ "vertices_coords": [
46
+ [0, 0], [0.5, 0], [1, 0],
47
+ [0, 0.5], [0.5, 0.5], [1, 0.5],
48
+ [0, 1], [0.5, 1], [1, 1],
49
+ ],
50
+ "edges_vertices": [
51
+ [0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
52
+ [1, 4], [4, 7],
53
+ [3, 4], [4, 5],
54
+ ],
55
+ "edges_assignment": [
56
+ "B", "B", "B", "B", "B", "B", "B", "B",
57
+ "V", "V", "V", "V",
58
+ ],
59
+ "edges_foldAngle": [
60
+ 0, 0, 0, 0, 0, 0, 0, 0,
61
+ 180, 180, 180, 180,
62
+ ],
63
+ "faces_vertices": [
64
+ [0, 1, 4, 3],
65
+ [1, 2, 5, 4],
66
+ [3, 4, 7, 6],
67
+ [4, 5, 8, 7],
68
+ ],
69
+ },
70
+ },
71
+ "letter_fold": {
72
+ "name": "letter_fold",
73
+ "description": "Tri-fold the paper like a letter (two parallel folds)",
74
+ "difficulty": 2,
75
+ "paper": {"width": 1.0, "height": 1.0},
76
+ "target_fold": {
77
+ "vertices_coords": [
78
+ [0, 0], [1, 0],
79
+ [0, 1/3], [1, 1/3],
80
+ [0, 2/3], [1, 2/3],
81
+ [0, 1], [1, 1],
82
+ ],
83
+ "edges_vertices": [
84
+ [0, 1], [1, 3], [3, 5], [5, 7], [7, 6], [6, 4], [4, 2], [2, 0],
85
+ [2, 3],
86
+ [4, 5],
87
+ ],
88
+ "edges_assignment": [
89
+ "B", "B", "B", "B", "B", "B", "B", "B",
90
+ "V", "M",
91
+ ],
92
+ "edges_foldAngle": [
93
+ 0, 0, 0, 0, 0, 0, 0, 0,
94
+ 180, -180,
95
+ ],
96
+ "faces_vertices": [
97
+ [0, 1, 3, 2],
98
+ [2, 3, 5, 4],
99
+ [4, 5, 7, 6],
100
+ ],
101
+ },
102
+ },
103
+ # --- Optigami patterns (from lib/patterns.ts) ---
104
+ "waterbomb": {
105
+ "name": "waterbomb",
106
+ "description": "Create a waterbomb base with valley folds on diagonals and mountain folds on midlines of a square sheet",
107
+ "difficulty": 3,
108
+ "paper": {"width": 2.0, "height": 2.0},
109
+ "target_fold": {
110
+ "vertices_coords": [
111
+ [-1, 1], [0, 1], [1, 1],
112
+ [-1, 0], [0, 0], [1, 0],
113
+ [-1, -1], [0, -1], [1, -1],
114
+ ],
115
+ "edges_vertices": [
116
+ [0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
117
+ [0, 4], [2, 4], [6, 4], [8, 4],
118
+ [1, 4], [3, 4], [5, 4], [7, 4],
119
+ ],
120
+ "edges_assignment": [
121
+ "B", "B", "B", "B", "B", "B", "B", "B",
122
+ "V", "V", "V", "V",
123
+ "M", "M", "M", "M",
124
+ ],
125
+ "edges_foldAngle": [
126
+ 0, 0, 0, 0, 0, 0, 0, 0,
127
+ 180, 180, 180, 180,
128
+ -180, -180, -180, -180,
129
+ ],
130
+ "faces_vertices": [
131
+ [0, 1, 4], [1, 2, 4],
132
+ [2, 5, 4], [5, 8, 4],
133
+ [8, 7, 4], [7, 6, 4],
134
+ [6, 3, 4], [3, 0, 4],
135
+ ],
136
+ },
137
+ },
138
+ "accordion": {
139
+ "name": "accordion",
140
+ "description": "Make an accordion (zig-zag) fold with alternating mountain and valley creases like a paper fan",
141
+ "difficulty": 2,
142
+ "paper": {"width": 2.0, "height": 2.0},
143
+ "target_fold": {
144
+ "vertices_coords": [
145
+ [-1, 1], [-0.5, 1], [0, 1], [0.5, 1], [1, 1],
146
+ [-1, -1], [-0.5, -1], [0, -1], [0.5, -1], [1, -1],
147
+ ],
148
+ "edges_vertices": [
149
+ [0, 1], [1, 2], [2, 3], [3, 4],
150
+ [5, 6], [6, 7], [7, 8], [8, 9],
151
+ [0, 5], [4, 9],
152
+ [1, 6], [2, 7], [3, 8],
153
+ ],
154
+ "edges_assignment": [
155
+ "B", "B", "B", "B",
156
+ "B", "B", "B", "B",
157
+ "B", "B",
158
+ "V", "M", "V",
159
+ ],
160
+ "edges_foldAngle": [
161
+ 0, 0, 0, 0,
162
+ 0, 0, 0, 0,
163
+ 0, 0,
164
+ 180, -180, 180,
165
+ ],
166
+ "faces_vertices": [
167
+ [0, 1, 6, 5], [1, 2, 7, 6],
168
+ [2, 3, 8, 7], [3, 4, 9, 8],
169
+ ],
170
+ },
171
+ },
172
+ "miura_ori": {
173
+ "name": "miura_ori",
174
+ "description": "Create a Miura-ori tessellation fold pattern on a 2x2 grid with offset zigzag vertices",
175
+ "difficulty": 3,
176
+ "paper": {"width": 2.0, "height": 2.0},
177
+ "target_fold": {
178
+ "vertices_coords": [
179
+ [-1, 1], [0, 1.2], [1, 1],
180
+ [-1, 0], [0, 0.2], [1, 0],
181
+ [-1, -1], [0, -0.8], [1, -1],
182
+ ],
183
+ "edges_vertices": [
184
+ [0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
185
+ [1, 4], [4, 7],
186
+ [3, 4], [4, 5],
187
+ ],
188
+ "edges_assignment": [
189
+ "B", "B", "B", "B", "B", "B", "B", "B",
190
+ "M", "M",
191
+ "V", "M",
192
+ ],
193
+ "edges_foldAngle": [
194
+ 0, 0, 0, 0, 0, 0, 0, 0,
195
+ -180, -180,
196
+ 180, -180,
197
+ ],
198
+ "faces_vertices": [
199
+ [0, 1, 4, 3], [1, 2, 5, 4],
200
+ [3, 4, 7, 6], [4, 5, 8, 7],
201
+ ],
202
+ },
203
+ },
204
+ }
205
+
206
+
207
def get_task(name: str | None = None) -> dict:
    """Look up a task definition by name, falling back to 'triangle'.

    Raises:
        ValueError: If the name is not a registered task.
    """
    key = "triangle" if name is None else name
    if key in TASKS:
        return TASKS[key]
    raise ValueError(f"Unknown task '{key}'. Available: {list(TASKS.keys())}")
214
+
215
+
216
def list_tasks() -> list[str]:
    """Return the names of all registered tasks, in definition order."""
    return [task_name for task_name in TASKS]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.100.0
2
+ numpy>=1.24.0
3
+ openenv-core[core]>=0.2.1
4
+ pydantic>=2.0.0
5
+ scipy>=1.10
6
+ uvicorn>=0.23.0
start.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/bin/sh
# Launch both services of the container:
#   - OpenEnv API (uvicorn) on port 8000
#   - Next.js frontend (node) on port 7860
#
# Both run as background jobs so this shell stays free to react to signals.
# With the frontend in the foreground, a SIGTERM (e.g. `docker stop`) sent to
# this script was not forwarded to either service.

# Start the OpenEnv API in the background on port 8000
python3 -m uvicorn origami_server.app:app --host 0.0.0.0 --port 8000 &
api_pid=$!

# Start Next.js frontend on port 7860
node server.js &
web_pid=$!

# Forward termination requests to both services.
trap 'kill "$api_pid" "$web_pid" 2>/dev/null' INT TERM

# The frontend still determines the script's lifetime and exit status,
# exactly as it did when it ran in the foreground.
wait "$web_pid"
status=$?

# Do not leave the API running once the frontend is gone.
kill "$api_pid" 2>/dev/null
exit "$status"
training/__init__.py ADDED
File without changes
training/reward.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GRPO reward functions for origami RL training.
2
+
3
+ Two reward functions (matching the 2048 pattern):
4
+ 1. valid_fold: Does the LLM output parse as valid FOLD JSON?
5
+ 2. shape_match: Simulate and compare to target shape.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+ from origami_server.engine.fold_parser import validate_fold
15
+ from origami_server.engine.shape_match import compute_shape_match
16
+ from origami_server.engine.simulate import simulate
17
+ from origami_server.tasks import get_task
18
+
19
+
20
+ def extract_fold_json(response: str) -> dict | None:
21
+ """Extract FOLD JSON from LLM response text.
22
+
23
+ Looks for JSON between ```json ... ``` or raw JSON object.
24
+ """
25
+ # Try fenced code block first
26
+ match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", response, re.DOTALL)
27
+ if match:
28
+ try:
29
+ return json.loads(match.group(1))
30
+ except json.JSONDecodeError:
31
+ pass
32
+
33
+ # Try raw JSON object
34
+ match = re.search(r"\{[^{}]*\"vertices_coords\"[^{}]*\}", response, re.DOTALL)
35
+ if match:
36
+ try:
37
+ return json.loads(match.group(0))
38
+ except json.JSONDecodeError:
39
+ pass
40
+
41
+ # Try parsing the whole response
42
+ try:
43
+ data = json.loads(response.strip())
44
+ if isinstance(data, dict) and "vertices_coords" in data:
45
+ return data
46
+ except (json.JSONDecodeError, ValueError):
47
+ pass
48
+
49
+ return None
50
+
51
+
52
def valid_fold(completions: list, **kwargs: Any) -> list[float]:
    """Reward 1: score each completion on whether it yields valid FOLD JSON.

    Each completion is indexed as ``completion[0]["content"]`` to obtain the
    response text.

    Scores:
        +1.0  FOLD JSON was extracted and ``validate_fold`` accepts it
        -0.5  JSON was extracted but ``validate_fold`` rejects it
        -2.0  no JSON could be extracted from the response

    Args:
        completions: One entry per generated completion.
        **kwargs: Accepted and ignored.

    Returns:
        One score per completion, in input order.
    """

    def _score(text: str) -> float:
        parsed = extract_fold_json(text)
        if parsed is None:
            return -2.0
        ok, _reason = validate_fold(parsed)
        return 1.0 if ok else -0.5

    return [_score(item[0]["content"]) for item in completions]
75
+
76
+
77
def shape_match(
    completions: list,
    task_name: str = "triangle",
    **kwargs: Any,
) -> list[float]:
    """Reward 2: simulate each generated fold and compare it to the target.

    Each completion is indexed as ``completion[0]["content"]`` to obtain the
    response text.

    Scores:
        similarity * 20.0  simulation succeeded (similarity comes from
                           ``compute_shape_match``)
        -1.0               FOLD data fails ``validate_fold``, the simulation
                           or comparison raises, or the similarity is not a
                           finite number
        -2.0               no FOLD JSON could be extracted from the response
         0.0               (for every completion) the *target* itself could
                           not be simulated

    Args:
        completions: One entry per generated completion.
        task_name: Key of the task whose ``target_fold`` is the reference.
        **kwargs: Accepted and ignored.

    Returns:
        One score per completion, in input order.

    Raises:
        ValueError: Propagated from ``get_task`` for an unknown task name.
    """
    target_fold = get_task(task_name)["target_fold"]

    # Simulate the reference shape once per call.
    try:
        target_positions = simulate(target_fold, crease_percent=1.0).positions
    except Exception:
        # Keep the neutral 0.0 fallback, but make the failure visible: it
        # means the task definition is broken, not the model outputs.
        import logging

        logging.getLogger(__name__).exception(
            "Target simulation failed for task %r; returning neutral rewards",
            task_name,
        )
        return [0.0] * len(completions)

    scores: list[float] = []
    for completion in completions:
        fold_data = extract_fold_json(completion[0]["content"])
        if fold_data is None:
            scores.append(-2.0)
            continue

        is_valid, _error = validate_fold(fold_data)
        if not is_valid:
            scores.append(-1.0)
            continue

        try:
            result = simulate(fold_data, crease_percent=1.0)
            # Cast so the reward is a plain Python float whatever numeric
            # type the metric returns.
            similarity = float(
                compute_shape_match(result.positions, target_positions)
            )
        except Exception:
            scores.append(-1.0)
            continue

        # A NaN/inf similarity is treated like a diverged simulation rather
        # than being passed on as a reward.
        scores.append(similarity * 20.0 if np.isfinite(similarity) else -1.0)

    return scores
training/train_grpo.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GRPO training script for origami RL.
2
+
3
+ Follows the OpenEnv + Unsloth pattern:
4
+ - LLM generates FOLD JSON crease patterns
5
+ - Two reward functions: valid_fold + shape_match
6
+ - GRPOTrainer from TRL handles the RL loop
7
+
8
+ Usage (Colab):
9
+ python -m training.train_grpo --task triangle --max_steps 600
10
+ """
11
+
12
+ import argparse
13
+
14
# Prompt given to the policy model. Placeholders {description}, {width} and
# {height} are filled in by build_prompt() from the task definition.
# NOTE(review): this text tells the model to place vertices in
# 0..{width} / 0..{height}, while the miura_ori target in tasks.py uses
# origin-centred coordinates (-1..1 on a 2.0 x 2.0 sheet) — confirm the
# shape-match metric is insensitive to that offset, or align the two.
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
that, when folded, produces the target shape described below.

Target: {description}
Paper size: {width} x {height}

Output a JSON object with these exact fields:
- vertices_coords: [[x, y], ...] — 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)
- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges
- edges_assignment: ["B"|"M"|"V", ...] — B=boundary, M=mountain fold, V=valley fold
- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative like -180, V: positive like 180, B: 0)

Rules:
- Boundary edges (B) must outline the paper rectangle
- At least one fold crease (M or V) must exist
- Mountain fold angles are negative (-180 to 0)
- Valley fold angles are positive (0 to 180)
- All vertex indices in edges must be valid (0 to N-1)

Output ONLY the JSON object wrapped in ```json ... ``` markers."""
34
+
35
+
36
def build_prompt(task: dict) -> str:
    """Render ``PROMPT_TEMPLATE`` for one task.

    Args:
        task: Task dictionary providing ``"description"`` and a ``"paper"``
            mapping with ``"width"`` and ``"height"``.

    Returns:
        The prompt text with the task's description and paper size filled in.
    """
    description = task["description"]
    paper_width = task["paper"]["width"]
    paper_height = task["paper"]["height"]
    return PROMPT_TEMPLATE.format(
        description=description,
        width=paper_width,
        height=paper_height,
    )
42
+
43
+
44
def main():
    """Parse CLI arguments and run GRPO training for a single origami task.

    The task's prompt is repeated 1000 times to form the dataset, the model
    is loaded in 4-bit through Unsloth with LoRA adapters attached, and TRL's
    ``GRPOTrainer`` is run with the ``valid_fold`` and ``shape_match``
    reward functions. Checkpoints go to ``outputs/``.
    """
    parser = argparse.ArgumentParser(description="GRPO training for origami RL")
    parser.add_argument("--task", default="triangle", help="Task name")
    parser.add_argument("--max_steps", type=int, default=600)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--model", default="unsloth/Qwen3-8B-unsloth-bnb-4bit")
    parser.add_argument("--lr", type=float, default=2e-4)
    args = parser.parse_args()

    # GRPO scores completions relative to the others sampled for the same
    # prompt, so a single generation cannot work; reject it up front.
    if args.num_generations < 2:
        parser.error("--num_generations must be at least 2 for GRPO")

    # Resolve the task before anything expensive so an unknown --task fails
    # immediately instead of after the heavy imports.
    from origami_server.tasks import get_task

    task = get_task(args.task)
    prompt_text = build_prompt(task)

    # --- These imports are heavy, only load when actually training ---
    # Unsloth must come first: it patches trl/transformers on import, and
    # importing trl beforehand leaves those optimisations unapplied.
    from unsloth import FastLanguageModel  # isort: skip

    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    from training.reward import shape_match, valid_fold

    # Build dataset (1000 copies of same prompt, like 2048)
    dataset = Dataset.from_list(
        [
            {
                "prompt": [{"role": "user", "content": prompt_text}],
                "answer": 0,
            }
        ]
        * 1000
    )

    # Load model with LoRA
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        load_in_4bit=True,
        max_seq_length=2048,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=16,
        use_gradient_checkpointing="unsloth",
    )

    # Wrap shape_match to inject task_name
    def shape_match_reward(completions, **kwargs):
        return shape_match(completions, task_name=args.task, **kwargs)

    # GRPO config
    training_args = GRPOConfig(
        temperature=1.0,
        learning_rate=args.lr,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=1024,
        max_completion_length=1024,
        max_steps=args.max_steps,
        save_steps=100,
        output_dir="outputs",
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[valid_fold, shape_match_reward],
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()
125
+
126
+
127
# Run training only when executed as a script (e.g.
# `python -m training.train_grpo`), not when the module is imported.
if __name__ == "__main__":
    main()