ianalin123 Claude Sonnet 4.6 committed on
Commit
c44bdad
·
1 Parent(s): 39c6d23

feat: implement origami RL environment (Phase 1)

Browse files

Core environment built with 4-wave agent team:
- CreaseGraph + PaperState geometry engine (Shapely-based)
- Kawasaki, Maekawa, BLB verifiers with correct cyclic ordering
- 8 FOLD target files (levels 1-3, parallel folds only)
- Dense reward: progress (45%) + economy + validity theorems
- OrigamiEnvironment: code-as-policy + step modes, clone() for GRPO
- Prompt formatter: code-as-policy and step-level templates
- GRPO training script with dry-run mode

28 tests passing. Run: python train.py --dry_run

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

.gitignore CHANGED
@@ -28,3 +28,5 @@ __pycache__/
28
 
29
  # Reference repos (not pushed to HF)
30
  .reference/
 
 
 
28
 
29
  # Reference repos (not pushed to HF)
30
  .reference/
31
+ *.pyc
32
+ __pycache__/
env/__init__.py ADDED
File without changes
env/environment.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import copy
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from .paper_state import PaperState
8
+ from .rewards import compute_reward, compute_terminal_reward, load_target, target_crease_edges
9
+ from .prompts import (
10
+ code_as_policy_prompt,
11
+ step_level_prompt,
12
+ parse_fold_list,
13
+ parse_single_fold,
14
+ )
15
+ from .verifier import check_all_vertices
16
+
17
+
18
+ TARGETS_DIR = Path(__file__).parent / 'targets'
19
+
20
+
21
class OrigamiEnvironment:
    """
    OpenEnv-compatible origami crease pattern environment.

    Supports two modes:
    - code_as_policy: model outputs complete fold sequence, gets terminal reward
    - step: model outputs one fold at a time, gets per-step reward
    """

    def __init__(
        self,
        mode: str = 'code_as_policy',  # 'code_as_policy' or 'step'
        max_steps: int = 8,
        targets_dir: Optional[str] = None,
    ):
        """
        Args:
            mode: 'code_as_policy' or 'step'.
            max_steps: fold budget per episode (also the prompt's max_folds
                in code_as_policy mode).
            targets_dir: directory of .fold files; defaults to env/targets.
        """
        assert mode in ('code_as_policy', 'step'), f"Unknown mode: {mode}"
        self.mode = mode
        self.max_steps = max_steps
        self.targets_dir = Path(targets_dir) if targets_dir else TARGETS_DIR

        # Per-episode state; populated by reset().
        self.paper: Optional[PaperState] = None
        self.target: Optional[dict] = None
        self.target_name: Optional[str] = None
        self.step_count: int = 0
        self.last_reward: Optional[dict] = None

        # Cache all available targets up front so reset() is cheap.
        self._targets = self._load_all_targets()

    def _load_all_targets(self) -> dict[str, dict]:
        """Load every .fold file in targets_dir, keyed by file stem."""
        targets = {}
        for fold_file in self.targets_dir.glob('*.fold'):
            with open(fold_file) as f:
                targets[fold_file.stem] = json.load(f)
        return targets

    def available_targets(self) -> list[str]:
        """Sorted names of all cached targets."""
        return sorted(self._targets.keys())

    def reset(self, target_name: Optional[str] = None) -> dict:
        """
        Reset environment to start of a new episode.

        Args:
            target_name: name of target (stem of .fold file). If None, picks level-1 randomly.

        Returns:
            observation dict with 'prompt' key containing the LLM prompt string.
        """
        # Local import: random is only needed for the unspecified-target path.
        import random

        if target_name:
            assert target_name in self._targets, f"Unknown target: {target_name}"
            self.target_name = target_name
        else:
            # Default to level-1 targets; fall back to any target if none
            # declare level 1 (targets without a 'level' key count as level 1).
            level1 = [k for k, v in self._targets.items() if v.get('level', 1) == 1]
            self.target_name = random.choice(level1 if level1 else list(self._targets.keys()))

        self.target = self._targets[self.target_name]
        self.paper = PaperState()
        self.step_count = 0
        self.last_reward = None

        return self._get_observation()

    def step(self, action) -> tuple[dict, dict, bool, dict]:
        """
        Execute an action.

        In code_as_policy mode: action is a string (model completion with <folds> tags)
        OR a list of fold dicts already parsed.
        In step mode: action is a string (single fold JSON) or dict.

        Returns:
            (observation, reward, done, info)
        """
        if self.mode == 'code_as_policy':
            return self._step_sequence(action)
        else:
            return self._step_single(action)

    def _step_sequence(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a complete fold sequence (code-as-policy mode)."""
        # Parse action if it's a string
        if isinstance(action, str):
            try:
                folds = parse_fold_list(action)
            except ValueError as e:
                bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
                # BUG FIX: record the failure reward (previously only
                # _step_single did this, so logging/observations disagreed
                # between modes).
                self.last_reward = bad_reward
                return self._get_observation(), bad_reward, True, self._info()
        else:
            folds = action  # already a list of dicts

        # Execute each fold sequentially
        last_result = {'valid': True, 'anchored': True, 'new_vertices': [], 'errors': []}
        for fold in folds:
            try:
                p1 = fold['from']
                p2 = fold['to']
                assignment = fold['assignment']
            except (KeyError, TypeError) as e:
                last_result = {'valid': False, 'anchored': False, 'new_vertices': [], 'errors': [str(e)]}
                break

            last_result = self.paper.add_crease(p1, p2, assignment)
            self.step_count += 1
            if not last_result['valid']:
                break  # stop at first invalid fold, partial credit

        # Terminal reward scores whatever portion of the sequence applied.
        reward = compute_terminal_reward(self.paper, self.target)
        self.last_reward = reward
        return self._get_observation(), reward, True, self._info()

    def _step_single(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a single fold (step mode)."""
        if isinstance(action, str):
            try:
                fold = parse_single_fold(action)
            except ValueError as e:
                # Unparseable completion: small penalty, episode continues
                # unless the step budget is exhausted.
                bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
                self.last_reward = bad_reward
                done = self.step_count >= self.max_steps
                return self._get_observation(), bad_reward, done, self._info()
        else:
            fold = action

        try:
            p1 = fold['from']
            p2 = fold['to']
            assignment = fold['assignment']
        except (KeyError, TypeError) as e:
            bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
            self.last_reward = bad_reward
            done = self.step_count >= self.max_steps
            return self._get_observation(), bad_reward, done, self._info()

        result = self.paper.add_crease(p1, p2, assignment)
        self.step_count += 1

        reward = compute_reward(self.paper, result, self.target)
        self.last_reward = reward

        # Episode ends on budget exhaustion or on earning the completion bonus.
        done = (
            self.step_count >= self.max_steps or
            reward.get('completion', 0) > 0
        )
        return self._get_observation(), reward, done, self._info()

    def _get_observation(self) -> dict:
        """Returns observation dict with the LLM prompt and raw state."""
        if self.mode == 'code_as_policy':
            prompt = code_as_policy_prompt(self.target, max_folds=self.max_steps)
        else:
            prompt = step_level_prompt(
                target=self.target,
                paper_state=self.paper,
                step=self.step_count,
                max_steps=self.max_steps,
                last_reward=self.last_reward,
            )

        return {
            'prompt': prompt,
            'target_name': self.target_name,
            'step': self.step_count,
            # Raw edge dict from the crease graph (not serialized FOLD JSON,
            # despite the key name).
            'paper_fold_json': self.paper.graph.edges if self.paper else {},
        }

    def _info(self) -> dict:
        """Returns diagnostic info dict for logging."""
        if self.paper is None:
            return {}

        interior = self.paper.graph.interior_vertices()
        vertex_scores = check_all_vertices(self.paper.graph)

        return {
            'local_foldability': (
                vertex_scores['kawasaki'] == 1.0 and
                vertex_scores['maekawa'] == 1.0
            ),
            'blb_satisfied': vertex_scores['blb'] == 1.0,
            'global_foldability': 'not_checked',  # NP-complete (Bern-Hayes 1996)
            'n_interior_vertices': len(interior),
            'n_creases': len(self.paper.graph.crease_edges()),
            'target_name': self.target_name,
        }

    def state(self) -> dict:
        """Returns current environment state for logging/inspection.

        Assumes reset() has been called (self.paper is not None).
        """
        return {
            'paper': {
                'vertices': dict(self.graph_vertices()),
                'edges': {
                    k: v for k, v in self.paper.graph.edges.items()
                    if v[2] in ('M', 'V')  # creases only; boundary edges omitted
                },
                'fold_history': self.paper.fold_history,
            },
            'target': self.target_name,
            'step': self.step_count,
            'mode': self.mode,
        }

    def graph_vertices(self) -> dict:
        """Helper: the current crease graph's vertex dict (id -> (x, y))."""
        return self.paper.graph.vertices

    def close(self):
        """Cleanup (no resources held; present for OpenEnv compatibility)."""
        pass

    def clone(self) -> 'OrigamiEnvironment':
        """Return a deep copy for parallel evaluation (used in GRPO).

        The paper state is deep-copied; the target dict is shared (treated
        as read-only).
        """
        new_env = OrigamiEnvironment(
            mode=self.mode,
            max_steps=self.max_steps,
            targets_dir=str(self.targets_dir),
        )
        if self.paper is not None:
            new_env.paper = copy.deepcopy(self.paper)
        new_env.target = self.target
        new_env.target_name = self.target_name
        new_env.step_count = self.step_count
        new_env.last_reward = self.last_reward
        return new_env
env/graph.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Optional
3
+
4
BOUNDARY_TOL = 1e-9
VERTEX_TOL = 1e-9


class CreaseGraph:
    """
    Planar graph representing an origami crease pattern on a unit square.

    Vertices: points in [0,1]x[0,1], deduplicated by proximity.
    Edges: segments between vertices, labeled M (mountain), V (valley), or B (boundary).
    """

    def __init__(self):
        self.vertices: dict[int, tuple[float, float]] = {}
        self.edges: dict[int, tuple[int, int, str]] = {}
        self.vertex_edges: dict[int, list[int]] = {}
        self._next_vertex_id: int = 0
        self._next_edge_id: int = 0

        # Seed the graph with the four unit-square corners...
        for corner in [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]:
            corner_id = self._next_vertex_id
            self._next_vertex_id += 1
            self.vertices[corner_id] = corner
            self.vertex_edges[corner_id] = []

        # ...and the four boundary ('B') edges connecting them.
        for end_a, end_b in ((0, 1), (1, 2), (2, 3), (3, 0)):
            boundary_id = self._next_edge_id
            self._next_edge_id += 1
            self.edges[boundary_id] = (end_a, end_b, 'B')
            self.vertex_edges[end_a].append(boundary_id)
            self.vertex_edges[end_b].append(boundary_id)

    def add_vertex(self, x: float, y: float) -> int:
        """Return the id of the vertex at (x, y), creating one if absent.

        An existing vertex within VERTEX_TOL on both axes is reused.
        """
        for known_id, (kx, ky) in self.vertices.items():
            if abs(kx - x) < VERTEX_TOL and abs(ky - y) < VERTEX_TOL:
                return known_id
        new_id = self._next_vertex_id
        self._next_vertex_id += 1
        self.vertices[new_id] = (float(x), float(y))
        self.vertex_edges[new_id] = []
        return new_id

    def add_edge(self, v1_id: int, v2_id: int, assignment: str) -> int:
        """Add an edge between two vertices and return its id.

        If an edge already connects the pair (in either direction), its id
        is returned and its assignment is left unchanged.
        """
        wanted = frozenset((v1_id, v2_id))
        for known_id, (end_a, end_b, _) in self.edges.items():
            if frozenset((end_a, end_b)) == wanted:
                return known_id
        new_id = self._next_edge_id
        self._next_edge_id += 1
        self.edges[new_id] = (v1_id, v2_id, assignment)
        self.vertex_edges[v1_id].append(new_id)
        self.vertex_edges[v2_id].append(new_id)
        return new_id

    def get_cyclic_edges(self, vertex_id: int) -> list[int]:
        """Edges incident to vertex_id, sorted by angle of the far endpoint."""
        cx, cy = self.vertices[vertex_id]

        def _angle(eid: int) -> float:
            end_a, end_b, _ = self.edges[eid]
            far = end_b if end_a == vertex_id else end_a
            fx, fy = self.vertices[far]
            return float(np.arctan2(fy - cy, fx - cx))

        return sorted(self.vertex_edges[vertex_id], key=_angle)

    def interior_vertices(self) -> list[int]:
        """Ids of vertices strictly inside the square (tolerance-padded)."""
        return [
            vid
            for vid, (x, y) in self.vertices.items()
            if BOUNDARY_TOL < x < 1.0 - BOUNDARY_TOL
            and BOUNDARY_TOL < y < 1.0 - BOUNDARY_TOL
        ]

    def split_edge(self, edge_id: int, new_vertex_id: int) -> tuple[int, int]:
        """Replace edge_id with two edges meeting at new_vertex_id.

        Both halves inherit the original assignment; returns their ids
        in (near-first-endpoint, near-second-endpoint) order.
        """
        end_a, end_b, assignment = self.edges.pop(edge_id)
        for endpoint in (end_a, end_b):
            if edge_id in self.vertex_edges[endpoint]:
                self.vertex_edges[endpoint].remove(edge_id)

        first_half = self._next_edge_id
        self._next_edge_id += 1
        self.edges[first_half] = (end_a, new_vertex_id, assignment)
        self.vertex_edges[end_a].append(first_half)
        self.vertex_edges[new_vertex_id].append(first_half)

        second_half = self._next_edge_id
        self._next_edge_id += 1
        self.edges[second_half] = (new_vertex_id, end_b, assignment)
        self.vertex_edges[new_vertex_id].append(second_half)
        self.vertex_edges[end_b].append(second_half)

        return (first_half, second_half)

    def crease_edges(self) -> list[int]:
        """Ids of all mountain/valley (non-boundary) edges."""
        return [eid for eid, (_, _, kind) in self.edges.items() if kind in ('M', 'V')]

    def boundary_midpoints(self) -> list[tuple[float, float]]:
        """Midpoint of every boundary ('B') edge, in edge-id order."""
        return [
            (
                (self.vertices[end_a][0] + self.vertices[end_b][0]) / 2.0,
                (self.vertices[end_a][1] + self.vertices[end_b][1]) / 2.0,
            )
            for end_a, end_b, kind in self.edges.values()
            if kind == 'B'
        ]
env/paper_state.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from shapely.geometry import LineString, Point, Polygon
3
+ from shapely.ops import unary_union
4
+ from typing import Optional
5
+ from .graph import CreaseGraph, VERTEX_TOL
6
+
7
+ UNIT_SQUARE_CORNERS = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
8
+
9
+ _UNIT_SQUARE = Polygon(UNIT_SQUARE_CORNERS)
10
+
11
+
12
class PaperState:
    """
    Represents the evolving crease pattern on a unit square [0,1]x[0,1].
    Uses CreaseGraph for the underlying data structure.
    """

    def __init__(self):
        self.graph = CreaseGraph()
        # Chronological record of every successfully applied fold.
        self.fold_history: list[dict] = []

    def anchor_points(self) -> list[tuple[float, float]]:
        """All valid fold endpoints: the square corners plus every graph vertex.

        A dict is used to deduplicate while keeping insertion order
        (corners first).
        """
        points: dict[tuple[float, float], None] = {}
        for corner in UNIT_SQUARE_CORNERS:
            points[corner] = None
        for vid, (x, y) in self.graph.vertices.items():
            points[(float(x), float(y))] = None
        return list(points.keys())

    def _is_anchor(self, pt: tuple[float, float]) -> bool:
        """True when pt coincides (within VERTEX_TOL) with an anchor point."""
        px, py = pt
        for ax, ay in self.anchor_points():
            if abs(ax - px) < VERTEX_TOL and abs(ay - py) < VERTEX_TOL:
                return True
        return False

    def add_crease(self, p1: list, p2: list, assignment: str) -> dict:
        """
        Insert a straight crease from p1 to p2 with the given assignment.

        Args:
            p1, p2: (x, y) endpoint pairs (any indexable of two numbers).
            assignment: 'M' (mountain) or 'V' (valley).

        Returns:
            {'valid': bool, 'anchored': bool,
             'new_vertices': list of (x, y) created at crossings,
             'errors': list of error codes}

        On success the graph gains: a vertex at every crossing with an
        existing edge, existing edges split at those crossings, and the new
        crease as a chain of edges; a record is appended to fold_history.
        """
        errors: list[str] = []

        if assignment not in ('M', 'V'):
            return {
                'valid': False,
                'anchored': False,
                'new_vertices': [],
                'errors': ['invalid_assignment'],
            }

        p1 = (float(p1[0]), float(p1[1]))
        p2 = (float(p2[0]), float(p2[1]))

        # Anchoring is reported, not enforced — the reward shaping penalizes
        # unanchored folds instead of rejecting them.
        anchored = self._is_anchor(p1) and self._is_anchor(p2)

        seg_len = np.hypot(p2[0] - p1[0], p2[1] - p1[1])
        if seg_len < VERTEX_TOL:
            errors.append('zero_length')
            return {'valid': False, 'anchored': anchored, 'new_vertices': [], 'errors': errors}

        new_line = LineString([p1, p2])

        if not _UNIT_SQUARE.contains(new_line) and not _UNIT_SQUARE.boundary.contains(new_line):
            clipped = new_line.intersection(_UNIT_SQUARE)
            if clipped.is_empty:
                errors.append('outside_bounds')
                return {'valid': False, 'anchored': anchored, 'new_vertices': [], 'errors': errors}
            # NOTE(review): a line only PARTIALLY inside is not clipped —
            # the original endpoints are used below. Confirm this is intended.

        # Find proper crossings with existing edges (skip shared endpoints;
        # collinear/MultiPoint overlaps are ignored).
        intersection_points: list[tuple[float, float]] = []
        for eid, (ev1, ev2, _) in list(self.graph.edges.items()):
            ex1, ey1 = self.graph.vertices[ev1]
            ex2, ey2 = self.graph.vertices[ev2]
            existing_line = LineString([(ex1, ey1), (ex2, ey2)])
            inter = new_line.intersection(existing_line)

            if inter.is_empty:
                continue

            if inter.geom_type == 'Point':
                ix, iy = inter.x, inter.y
                touches_endpoint = (
                    (abs(ix - ex1) < VERTEX_TOL and abs(iy - ey1) < VERTEX_TOL)
                    or (abs(ix - ex2) < VERTEX_TOL and abs(iy - ey2) < VERTEX_TOL)
                )
                if touches_endpoint:
                    continue
                intersection_points.append((ix, iy))
            # MultiPoint or LineString intersections (collinear) are skipped

        # For each crossing: create (or reuse) the vertex, then split every
        # existing edge passing through it so the graph stays planar.
        new_vertex_coords: list[tuple[float, float]] = []
        for ix, iy in intersection_points:
            before = set(self.graph.vertices.keys())
            vid = self.graph.add_vertex(ix, iy)
            if vid not in before:
                new_vertex_coords.append((ix, iy))

            for eid in list(self.graph.edges.keys()):
                if eid not in self.graph.edges:
                    # Removed by a split earlier in this pass.
                    continue
                ev1, ev2, _ = self.graph.edges[eid]
                ex1, ey1 = self.graph.vertices[ev1]
                ex2, ey2 = self.graph.vertices[ev2]
                seg = LineString([(ex1, ey1), (ex2, ey2)])
                pt = Point(ix, iy)
                if seg.distance(pt) < VERTEX_TOL:
                    # Only split edges that do not already end at the vertex.
                    if ev1 != vid and ev2 != vid:
                        self.graph.split_edge(eid, vid)

        v1_id = self.graph.add_vertex(p1[0], p1[1])
        v2_id = self.graph.add_vertex(p2[0], p2[1])

        # Chain the crease through its crossings, ordered by distance from p1.
        waypoints = [p1] + sorted(
            intersection_points,
            key=lambda pt: np.hypot(pt[0] - p1[0], pt[1] - p1[1]),
        ) + [p2]

        waypoint_ids = []
        for wp in waypoints:
            wid = self.graph.add_vertex(wp[0], wp[1])
            waypoint_ids.append(wid)

        for i in range(len(waypoint_ids) - 1):
            wa = waypoint_ids[i]
            wb = waypoint_ids[i + 1]
            if wa != wb:
                self.graph.add_edge(wa, wb, assignment)

        record = {
            'p1': p1,
            'p2': p2,
            'assignment': assignment,
            'anchored': anchored,
            'new_vertices': new_vertex_coords,
        }
        self.fold_history.append(record)

        return {
            'valid': True,
            'anchored': anchored,
            'new_vertices': new_vertex_coords,
            'errors': errors,
        }

    def crease_edges(self) -> list[dict]:
        """M/V edges as {'v1': (x, y), 'v2': (x, y), 'assignment': str} dicts."""
        result = []
        for eid in self.graph.crease_edges():
            v1, v2, assignment = self.graph.edges[eid]
            x1, y1 = self.graph.vertices[v1]
            x2, y2 = self.graph.vertices[v2]
            result.append({'v1': (x1, y1), 'v2': (x2, y2), 'assignment': assignment})
        return result
env/prompts.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Optional
4
+
5
+ _CORNERS = {(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)}
6
+ _BOUNDARY_X = {0.0, 1.0}
7
+ _BOUNDARY_Y = {0.0, 1.0}
8
+
9
+
10
+ def _is_corner(x: float, y: float) -> bool:
11
+ return (round(x, 4), round(y, 4)) in _CORNERS
12
+
13
+
14
+ def _is_boundary(x: float, y: float) -> bool:
15
+ return x in _BOUNDARY_X or y in _BOUNDARY_Y
16
+
17
+
18
def format_target_for_prompt(target: dict) -> str:
    """Render the target's M/V creases as one human-readable line each."""
    coords = target["vertices_coords"]
    descriptions = []
    for (a, b), kind in zip(target["edges_vertices"], target["edges_assignment"]):
        if kind not in ("M", "V"):
            continue  # boundary edges are implicit; only creases are shown
        ax, ay = coords[a]
        bx, by = coords[b]
        name = "Mountain" if kind == "M" else "Valley"
        descriptions.append(
            f"{name} fold: ({round(ax, 4)}, {round(ay, 4)}) -> ({round(bx, 4)}, {round(by, 4)})"
        )
    return "\n".join(descriptions)
34
+
35
+
36
def format_anchor_points(paper_state) -> str:
    """Group the paper's anchor points into corners / boundary / intersections."""
    buckets: dict = {"Corners": [], "Boundary pts": [], "Intersections": []}

    for ax, ay in paper_state.anchor_points():
        rx, ry = round(ax, 4), round(ay, 4)
        if _is_corner(rx, ry):
            buckets["Corners"].append((rx, ry))
        elif _is_boundary(rx, ry):
            buckets["Boundary pts"].append((rx, ry))
        else:
            buckets["Intersections"].append((rx, ry))

    rendered = []
    for label, pts in buckets.items():
        if pts:
            joined = " ".join(f"({x},{y})" for x, y in pts)
            rendered.append(f" {label}: {joined}")

    return "\n".join(rendered)
62
+
63
+
64
def format_crease_history(paper_state) -> str:
    """Numbered list of folds placed so far, or 'none' for a fresh sheet."""
    if not paper_state.fold_history:
        return "none"

    rendered = []
    for idx, fold in enumerate(paper_state.fold_history, 1):
        name = "Mountain" if fold["assignment"] == "M" else "Valley"
        (ax, ay), (bx, by) = fold["p1"], fold["p2"]
        rax, ray = round(ax, 4), round(ay, 4)
        rbx, rby = round(bx, 4), round(by, 4)
        rendered.append(f" {idx}. {name} fold: ({rax}, {ray}) -> ({rbx}, {rby})")

    return "\n".join(rendered)
79
+
80
+
81
def format_reward_feedback(reward: Optional[dict]) -> str:
    """
    Render a reward dict as ' k=v k=v ...' feedback for the step prompt.

    Known reward components come first in a fixed order; any extra keys
    follow in dict order. Non-numeric extras (e.g. the 'error' string in a
    parse-failure reward) are rendered verbatim.
    """
    if not reward:
        return "(no feedback yet)"

    keys = ["kawasaki", "maekawa", "blb", "progress", "economy", "total"]
    parts = []
    for k in keys:
        if k in reward:
            parts.append(f"{k}={reward[k]:.2f}")

    for k, v in reward.items():
        if k not in keys:
            # BUG FIX: parse-failure rewards carry a string 'error' value,
            # which previously raised on the ':.2f' float format spec.
            if isinstance(v, (int, float)):
                parts.append(f"{k}={v:.2f}")
            else:
                parts.append(f"{k}={v}")

    return " " + " ".join(parts)
96
+
97
+
98
def code_as_policy_prompt(target: dict, max_folds: int = 8) -> str:
    """Build the single-shot prompt asking for a complete fold sequence."""
    target_block = format_target_for_prompt(target)
    prompt = f"""You are an origami designer. Generate a fold sequence for a unit square [0,1]x[0,1].

TARGET CREASE PATTERN:
{target_block}

RULES (must hold at every interior vertex):
- Kawasaki: alternating sector angles sum equally (each half = 180 degrees)
- Maekawa: |mountain_count - valley_count| = 2
- Big-Little-Big: folds bounding the smallest sector must have opposite types (one M, one V)

INITIAL ANCHOR POINTS (valid fold endpoints — new ones appear when creases intersect):
Corners: (0.0,0.0) (1.0,0.0) (1.0,1.0) (0.0,1.0)
Midpoints: (0.0,0.5) (0.5,0.0) (1.0,0.5) (0.5,1.0)
Note: new anchor points are created at crease intersections.

Output at most {max_folds} folds. Both endpoints must be valid anchor points.
Output ONLY the JSON list, wrapped in <folds> tags:

<folds>
[
{{"instruction": "Describe the fold in plain English", "from": [x1, y1], "to": [x2, y2], "assignment": "V"}},
{{"instruction": "...", "from": [x1, y1], "to": [x2, y2], "assignment": "M"}}
]
</folds>"""
    return prompt
124
+
125
+
126
def step_level_prompt(
    target: dict,
    paper_state,
    step: int,
    max_steps: int,
    last_reward: Optional[dict] = None,
) -> str:
    """Build the per-step prompt showing target, history, anchors and feedback."""
    target_block = format_target_for_prompt(target)
    history_block = format_crease_history(paper_state)
    anchors_block = format_anchor_points(paper_state)
    reward_block = format_reward_feedback(last_reward)

    return f"""You are an origami designer building a crease pattern step by step.

TARGET:
{target_block}

CURRENT STATE (step {step} of {max_steps}):
Creases placed:
{history_block}

AVAILABLE ANCHOR POINTS:
{anchors_block}

LAST REWARD:
{reward_block}

Add the NEXT crease. Both endpoints must be listed anchor points above.
Output ONLY valid JSON (no extra text):
{{"instruction": "...", "from": [x1, y1], "to": [x2, y2], "assignment": "M" or "V"}}"""
156
+
157
+
158
def parse_fold_list(completion: str) -> list[dict]:
    """
    Extract and validate the JSON fold list from a <folds>...</folds> block.

    Returns folds normalized to float coordinates, each with an
    'instruction' field (defaulting to ''). Raises ValueError on any
    structural problem.
    """
    tag_match = re.search(r"<folds>(.*?)</folds>", completion, re.IGNORECASE | re.DOTALL)
    if tag_match is None:
        raise ValueError("No <folds>...</folds> tags found in completion")

    try:
        parsed = json.loads(tag_match.group(1).strip())
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse JSON inside <folds> tags: {e}") from e

    if not isinstance(parsed, list):
        raise ValueError(f"Expected a JSON list inside <folds> tags, got {type(parsed).__name__}")

    def _coerce_point(value, index: int, field: str) -> list[float]:
        # A point must be a JSON list of exactly two numbers.
        ok = (
            isinstance(value, list)
            and len(value) == 2
            and all(isinstance(c, (int, float)) for c in value)
        )
        if not ok:
            raise ValueError(f"Fold {index} '{field}' must be a list of 2 numbers, got {value!r}")
        return [float(value[0]), float(value[1])]

    normalized = []
    for index, fold in enumerate(parsed):
        if not isinstance(fold, dict):
            raise ValueError(f"Fold {index} is not a dict: {fold!r}")

        for field in ("from", "to", "assignment"):
            if field not in fold:
                raise ValueError(f"Fold {index} missing required field '{field}'")

        src = _coerce_point(fold["from"], index, "from")
        dst = _coerce_point(fold["to"], index, "to")

        if not isinstance(fold["assignment"], str):
            raise ValueError(f"Fold {index} 'assignment' must be a string")

        normalized.append(
            {
                "from": src,
                "to": dst,
                "assignment": fold["assignment"],
                "instruction": fold.get("instruction", ""),
            }
        )

    return normalized
212
+
213
+
214
def parse_single_fold(completion: str) -> dict:
    """
    Extract the outermost {...} span from a completion and validate that it
    parses to a dict carrying 'from', 'to' and 'assignment'.

    Unlike parse_fold_list, values are returned as-parsed (no float
    normalization). Raises ValueError on any problem.
    """
    brace_open = completion.find("{")
    brace_close = completion.rfind("}")
    if brace_open == -1 or brace_close == -1 or brace_close <= brace_open:
        raise ValueError("No JSON object found in completion")

    snippet = completion[brace_open : brace_close + 1]
    try:
        fold = json.loads(snippet)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse JSON from completion: {e}") from e

    if not isinstance(fold, dict):
        raise ValueError(f"Expected a JSON object, got {type(fold).__name__}")

    for field in ("from", "to", "assignment"):
        if field not in fold:
            raise ValueError(f"Missing required field '{field}' in fold JSON")

    return fold
env/rewards.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from .verifier import check_all_vertices, geometric_crease_coverage
3
+ from .paper_state import PaperState
4
+
5
+
6
def load_target(target_path: str) -> dict:
    """Read a .fold target file (JSON on disk) and return its contents."""
    with open(target_path) as handle:
        return json.load(handle)
10
+
11
+
12
def target_crease_edges(target: dict) -> list[dict]:
    """
    Extract crease edges from a FOLD target dict as list of
    {'v1': (x1,y1), 'v2': (x2,y2), 'assignment': 'M'|'V'} dicts.

    Boundary ('B') edges are excluded.
    """
    coords = target['vertices_coords']
    pairs = zip(target['edges_vertices'], (
        target['edges_assignment'][i] for i in range(len(target['edges_vertices']))
    ))
    return [
        {
            'v1': tuple(coords[a]),
            'v2': tuple(coords[b]),
            'assignment': kind,
        }
        for (a, b), kind in pairs
        if kind in ('M', 'V')
    ]
28
+
29
+
30
def compute_reward(
    state: PaperState,
    action_result: dict,
    target: dict,
) -> dict:
    """
    Compute the full reward dict for a fold action.

    Args:
        state: current PaperState AFTER the action was applied
        action_result: {'valid': bool, 'anchored': bool, 'new_vertices': list, 'errors': list}
        target: FOLD target dict

    Returns dict with keys:
        format, anchored, kawasaki, maekawa, blb, progress, economy, completion, efficiency, total
    """
    # Gate 1: format — an action that failed to parse/apply short-circuits
    # everything else with a flat penalty.
    if not action_result.get('valid', False):
        return {'format': 0.0, 'total': -0.1}

    reward = {'format': 1.0}

    # Gate 2: anchoring — unanchored endpoints earn only partial credit.
    reward['anchored'] = 1.0 if action_result.get('anchored', False) else 0.3

    # Vertex-level flat-foldability theorem checks over interior vertices.
    vertex_scores = check_all_vertices(state.graph)
    reward['kawasaki'] = vertex_scores['kawasaki']
    reward['maekawa'] = vertex_scores['maekawa']
    reward['blb'] = vertex_scores['blb']

    # Geometric progress toward the target crease pattern.
    coverage, economy = geometric_crease_coverage(state, target_crease_edges(target))
    reward['progress'] = coverage
    reward['economy'] = economy

    # Completion bonus: high coverage AND all theorem checks fully satisfied.
    theorems_ok = (
        reward['kawasaki'] == 1.0
        and reward['maekawa'] == 1.0
        and reward['blb'] == 1.0
    )
    reward['completion'] = 10.0 if (coverage > 0.9 and theorems_ok) else 0.0

    # Small per-step cost to discourage wasted folds.
    reward['efficiency'] = -0.01

    # Weighted total: progress dominates the shaped terms.
    reward['total'] = (
        0.05 * reward['anchored'] +
        0.08 * reward['kawasaki'] +
        0.07 * reward['maekawa'] +
        0.05 * reward['blb'] +
        0.45 * reward['progress'] +
        0.10 * reward['economy'] +
        reward['completion'] +
        reward['efficiency']
    )
    return reward
88
+
89
+
90
def compute_terminal_reward(state: PaperState, target: dict) -> dict:
    """Score the final state of a complete fold sequence.

    Reuses compute_reward with a synthetic 'applied cleanly' action result,
    so the terminal reward reflects only the resulting crease pattern.
    """
    applied_ok = {'valid': True, 'anchored': True, 'new_vertices': [], 'errors': []}
    return compute_reward(state, applied_ok, target)
env/targets/__init__.py ADDED
File without changes
env/targets/accordion_3h.fold ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.0, 0.25],
8
+ [1.0, 0.25],
9
+ [0.0, 0.5],
10
+ [1.0, 0.5],
11
+ [0.0, 0.75],
12
+ [1.0, 0.75]
13
+ ],
14
+ "edges_vertices": [
15
+ [0, 1],
16
+ [1, 5],
17
+ [5, 7],
18
+ [7, 9],
19
+ [9, 2],
20
+ [2, 3],
21
+ [3, 8],
22
+ [8, 6],
23
+ [6, 4],
24
+ [4, 0],
25
+ [4, 5],
26
+ [6, 7],
27
+ [8, 9]
28
+ ],
29
+ "edges_assignment": [
30
+ "B",
31
+ "B",
32
+ "B",
33
+ "B",
34
+ "B",
35
+ "B",
36
+ "B",
37
+ "B",
38
+ "B",
39
+ "B",
40
+ "V",
41
+ "M",
42
+ "V"
43
+ ],
44
+ "edges_foldAngle": [
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 0,
53
+ 0,
54
+ 0,
55
+ -180,
56
+ -180,
57
+ -180
58
+ ],
59
+ "faces_vertices": [
60
+ [0, 1, 5, 4],
61
+ [4, 5, 7, 6],
62
+ [6, 7, 9, 8],
63
+ [8, 9, 2, 3]
64
+ ],
65
+ "level": 3,
66
+ "description": "Three alternating horizontal folds at y=0.25 (valley), y=0.5 (mountain), y=0.75 (valley) forming an accordion"
67
+ }
env/targets/accordion_4h.fold ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.0, 0.2],
8
+ [1.0, 0.2],
9
+ [0.0, 0.4],
10
+ [1.0, 0.4],
11
+ [0.0, 0.6],
12
+ [1.0, 0.6],
13
+ [0.0, 0.8],
14
+ [1.0, 0.8]
15
+ ],
16
+ "edges_vertices": [
17
+ [0, 1],
18
+ [1, 5],
19
+ [5, 7],
20
+ [7, 9],
21
+ [9, 11],
22
+ [11, 2],
23
+ [2, 3],
24
+ [3, 10],
25
+ [10, 8],
26
+ [8, 6],
27
+ [6, 4],
28
+ [4, 0],
29
+ [4, 5],
30
+ [6, 7],
31
+ [8, 9],
32
+ [10, 11]
33
+ ],
34
+ "edges_assignment": [
35
+ "B",
36
+ "B",
37
+ "B",
38
+ "B",
39
+ "B",
40
+ "B",
41
+ "B",
42
+ "B",
43
+ "B",
44
+ "B",
45
+ "B",
46
+ "B",
47
+ "V",
48
+ "M",
49
+ "V",
50
+ "M"
51
+ ],
52
+ "edges_foldAngle": [
53
+ 0,
54
+ 0,
55
+ 0,
56
+ 0,
57
+ 0,
58
+ 0,
59
+ 0,
60
+ 0,
61
+ 0,
62
+ 0,
63
+ 0,
64
+ 0,
65
+ -180,
66
+ -180,
67
+ -180,
68
+ -180
69
+ ],
70
+ "faces_vertices": [
71
+ [0, 1, 5, 4],
72
+ [4, 5, 7, 6],
73
+ [6, 7, 9, 8],
74
+ [8, 9, 11, 10],
75
+ [10, 11, 2, 3]
76
+ ],
77
+ "level": 3,
78
+ "description": "Four alternating horizontal folds at y=0.2 (valley), y=0.4 (mountain), y=0.6 (valley), y=0.8 (mountain) forming an accordion"
79
+ }
env/targets/diagonal_anti.fold ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0]
7
+ ],
8
+ "edges_vertices": [
9
+ [0, 1],
10
+ [1, 2],
11
+ [2, 3],
12
+ [3, 0],
13
+ [1, 3]
14
+ ],
15
+ "edges_assignment": [
16
+ "B",
17
+ "B",
18
+ "B",
19
+ "B",
20
+ "M"
21
+ ],
22
+ "edges_foldAngle": [
23
+ 0,
24
+ 0,
25
+ 0,
26
+ 0,
27
+ -180
28
+ ],
29
+ "faces_vertices": [
30
+ [0, 1, 3],
31
+ [1, 2, 3]
32
+ ],
33
+ "level": 1,
34
+ "description": "One mountain fold along the anti-diagonal from (1,0) to (0,1)"
35
+ }
env/targets/diagonal_main.fold ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0]
7
+ ],
8
+ "edges_vertices": [
9
+ [0, 1],
10
+ [1, 2],
11
+ [2, 3],
12
+ [3, 0],
13
+ [0, 2]
14
+ ],
15
+ "edges_assignment": [
16
+ "B",
17
+ "B",
18
+ "B",
19
+ "B",
20
+ "V"
21
+ ],
22
+ "edges_foldAngle": [
23
+ 0,
24
+ 0,
25
+ 0,
26
+ 0,
27
+ -180
28
+ ],
29
+ "faces_vertices": [
30
+ [0, 1, 2],
31
+ [0, 2, 3]
32
+ ],
33
+ "level": 1,
34
+ "description": "One valley fold along the main diagonal from (0,0) to (1,1)"
35
+ }
env/targets/half_horizontal.fold ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.0, 0.5],
8
+ [1.0, 0.5]
9
+ ],
10
+ "edges_vertices": [
11
+ [0, 1],
12
+ [1, 5],
13
+ [5, 2],
14
+ [2, 3],
15
+ [3, 4],
16
+ [4, 0],
17
+ [4, 5]
18
+ ],
19
+ "edges_assignment": [
20
+ "B",
21
+ "B",
22
+ "B",
23
+ "B",
24
+ "B",
25
+ "B",
26
+ "V"
27
+ ],
28
+ "edges_foldAngle": [
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 0,
35
+ -180
36
+ ],
37
+ "faces_vertices": [
38
+ [0, 1, 5, 4],
39
+ [4, 5, 2, 3]
40
+ ],
41
+ "level": 1,
42
+ "description": "One valley fold along y=0.5, folding the paper in half horizontally"
43
+ }
env/targets/half_vertical.fold ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.5, 0.0],
8
+ [0.5, 1.0]
9
+ ],
10
+ "edges_vertices": [
11
+ [0, 4],
12
+ [4, 1],
13
+ [1, 2],
14
+ [2, 5],
15
+ [5, 3],
16
+ [3, 0],
17
+ [4, 5]
18
+ ],
19
+ "edges_assignment": [
20
+ "B",
21
+ "B",
22
+ "B",
23
+ "B",
24
+ "B",
25
+ "B",
26
+ "M"
27
+ ],
28
+ "edges_foldAngle": [
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 0,
35
+ -180
36
+ ],
37
+ "faces_vertices": [
38
+ [0, 4, 5, 3],
39
+ [4, 1, 2, 5]
40
+ ],
41
+ "level": 1,
42
+ "description": "One mountain fold along x=0.5, folding the paper in half vertically"
43
+ }
env/targets/thirds_h.fold ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.0, 0.3333333333333333],
8
+ [1.0, 0.3333333333333333],
9
+ [0.0, 0.6666666666666666],
10
+ [1.0, 0.6666666666666666]
11
+ ],
12
+ "edges_vertices": [
13
+ [0, 1],
14
+ [1, 5],
15
+ [5, 7],
16
+ [7, 2],
17
+ [2, 3],
18
+ [3, 6],
19
+ [6, 4],
20
+ [4, 0],
21
+ [4, 5],
22
+ [6, 7]
23
+ ],
24
+ "edges_assignment": [
25
+ "B",
26
+ "B",
27
+ "B",
28
+ "B",
29
+ "B",
30
+ "B",
31
+ "B",
32
+ "B",
33
+ "V",
34
+ "V"
35
+ ],
36
+ "edges_foldAngle": [
37
+ 0,
38
+ 0,
39
+ 0,
40
+ 0,
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ -180,
46
+ -180
47
+ ],
48
+ "faces_vertices": [
49
+ [0, 1, 5, 4],
50
+ [4, 5, 7, 6],
51
+ [6, 7, 2, 3]
52
+ ],
53
+ "level": 2,
54
+ "description": "Two parallel valley folds at y=1/3 and y=2/3, dividing the paper into horizontal thirds"
55
+ }
env/targets/thirds_v.fold ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vertices_coords": [
3
+ [0.0, 0.0],
4
+ [1.0, 0.0],
5
+ [1.0, 1.0],
6
+ [0.0, 1.0],
7
+ [0.3333333333333333, 0.0],
8
+ [0.6666666666666666, 0.0],
9
+ [0.3333333333333333, 1.0],
10
+ [0.6666666666666666, 1.0]
11
+ ],
12
+ "edges_vertices": [
13
+ [0, 4],
14
+ [4, 5],
15
+ [5, 1],
16
+ [1, 2],
17
+ [2, 7],
18
+ [7, 6],
19
+ [6, 3],
20
+ [3, 0],
21
+ [4, 6],
22
+ [5, 7]
23
+ ],
24
+ "edges_assignment": [
25
+ "B",
26
+ "B",
27
+ "B",
28
+ "B",
29
+ "B",
30
+ "B",
31
+ "B",
32
+ "B",
33
+ "M",
34
+ "M"
35
+ ],
36
+ "edges_foldAngle": [
37
+ 0,
38
+ 0,
39
+ 0,
40
+ 0,
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ -180,
46
+ -180
47
+ ],
48
+ "faces_vertices": [
49
+ [0, 4, 6, 3],
50
+ [4, 5, 7, 6],
51
+ [5, 1, 2, 7]
52
+ ],
53
+ "level": 2,
54
+ "description": "Two parallel mountain folds at x=1/3 and x=2/3, dividing the paper into vertical thirds"
55
+ }
env/targets/validator.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Validates all .fold target files against origami theorems.
3
+ Run directly: python -m env.targets.validator
4
+ """
5
+ import json
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ from ..graph import CreaseGraph
11
+ from ..verifier import check_kawasaki_at_vertex, check_maekawa_at_vertex, check_blb_at_vertex
12
+
13
+
14
def build_graph_from_fold(fold_data: dict) -> CreaseGraph:
    """
    Rebuild a CreaseGraph from a parsed FOLD-format dictionary.

    Used by the target validator to re-check theorem compliance of the
    stored crease patterns.
    """
    graph = CreaseGraph()

    # File vertex index -> graph-assigned vertex id (add_vertex dedups
    # coincident points, so ids may not equal file indices).
    id_of = {
        idx: graph.add_vertex(float(x), float(y))
        for idx, (x, y) in enumerate(fold_data['vertices_coords'])
    }

    # Boundary edges created by the CreaseGraph constructor may coincide
    # with entries here; add_edge is expected to deduplicate them.
    for (a, b), assignment in zip(fold_data['edges_vertices'],
                                  fold_data['edges_assignment']):
        graph.add_edge(id_of[a], id_of[b], assignment)

    return graph
39
+
40
+
41
def validate_target(fold_path: str) -> dict:
    """
    Validate one .fold target file against structural and theorem checks.

    Returns {'file': str, 'valid': bool, 'issues': list[str],
    'interior_vertices': int}; interior_vertices is -1 when required
    fields are missing and the graph could not be built.
    """
    with open(fold_path) as fh:
        fold_data = json.load(fh)

    # Structural sanity first: required fields must exist before any
    # deeper geometric checks can run.
    issues = [
        f"Missing field: {field}"
        for field in ('vertices_coords', 'edges_vertices',
                      'edges_assignment', 'edges_foldAngle')
        if field not in fold_data
    ]
    if issues:
        return {
            'file': os.path.basename(fold_path),
            'valid': False,
            'issues': issues,
            'interior_vertices': -1,
        }

    n_edges = len(fold_data['edges_vertices'])
    for field in ('edges_assignment', 'edges_foldAngle'):
        if len(fold_data[field]) != n_edges:
            issues.append(f"{field} length mismatch")

    # Theorem checks at every interior vertex of the reconstructed graph.
    graph = build_graph_from_fold(fold_data)
    interior = graph.interior_vertices()

    for v_id in interior:
        kaw_ok, alt_sum = check_kawasaki_at_vertex(v_id, graph)
        if not kaw_ok:
            issues.append(f"Kawasaki violated at vertex {v_id} (alt_sum={alt_sum:.6f})")

        if not check_maekawa_at_vertex(v_id, graph):
            issues.append(f"Maekawa violated at vertex {v_id}")

        blb_violations = check_blb_at_vertex(v_id, graph)
        if blb_violations:
            issues.append(f"BLB violated at vertex {v_id}: {blb_violations}")

    return {
        'file': os.path.basename(fold_path),
        'valid': not issues,
        'issues': issues,
        'interior_vertices': len(interior),
    }
88
+
89
+
90
def validate_all(targets_dir: str = None) -> bool:
    """Validate every .fold file in targets_dir (default: this package's
    own directory). Prints a per-file report; returns True iff all pass."""
    root = Path(targets_dir) if targets_dir is not None else Path(__file__).parent

    fold_files = sorted(root.glob('*.fold'))
    if not fold_files:
        print("No .fold files found")
        return False

    all_pass = True
    for fold_path in fold_files:
        result = validate_target(str(fold_path))
        status = "OK" if result['valid'] else "FAIL"
        print(f" [{status}] {result['file']} — {result['interior_vertices']} interior vertices")
        for issue in result['issues']:
            print(f" ! {issue}")
        all_pass = all_pass and result['valid']

    return all_pass
114
+
115
+
116
if __name__ == '__main__':
    # CLI entry point: exit code 0 only when every target file passes
    # validation, so CI/shell callers can gate on it.
    print("Validating targets...")
    ok = validate_all()
    sys.exit(0 if ok else 1)
env/targets/validator_check.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Quick structural sanity check of all .fold target files.

Asserts that the per-edge arrays line up and that edges/faces only
reference existing vertices, then prints a one-line summary per file.
"""
import json
import os

# Resolve the targets directory relative to this file instead of a
# hard-coded absolute path, so the check runs on any machine/checkout.
targets_dir = os.path.dirname(os.path.abspath(__file__))

# Sorted for deterministic report order across filesystems.
for fname in sorted(os.listdir(targets_dir)):
    if not fname.endswith(".fold"):
        continue
    with open(os.path.join(targets_dir, fname)) as f:
        d = json.load(f)
    n_v = len(d["vertices_coords"])
    n_e = len(d["edges_vertices"])
    # Per-edge arrays must be parallel to edges_vertices.
    assert len(d["edges_assignment"]) == n_e, f"{fname}: assignment length mismatch"
    assert len(d["edges_foldAngle"]) == n_e, f"{fname}: foldAngle length mismatch"
    for e in d["edges_vertices"]:
        assert e[0] < n_v and e[1] < n_v, f"{fname}: edge references invalid vertex"
    for face in d["faces_vertices"]:
        for vi in face:
            assert vi < n_v, f"{fname}: face references invalid vertex"
    creases = [i for i, a in enumerate(d["edges_assignment"]) if a in ('M', 'V')]
    print(f"{fname}: {n_v} vertices, {n_e} edges, {len(creases)} creases, level={d.get('level','?')} OK")
env/verifier.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .graph import CreaseGraph
3
+ from .paper_state import PaperState
4
+
5
+
6
def _compute_sector_angles(vertex_id: int, graph: CreaseGraph) -> list[float]:
    """Return the CCW sector angles between consecutive cyclic edges at a vertex."""
    ordered = graph.get_cyclic_edges(vertex_id)
    vx, vy = graph.vertices[vertex_id]

    def ray_angle(eid):
        # Direction from the vertex toward the edge's far endpoint.
        a, b, _ = graph.edges[eid]
        far = b if a == vertex_id else a
        fx, fy = graph.vertices[far]
        return np.arctan2(fy - vy, fx - vx)

    rays = [ray_angle(eid) for eid in ordered]
    count = len(rays)

    sectors = []
    for i in range(count):
        span = rays[(i + 1) % count] - rays[i]
        # Normalize the wrap-around difference into [0, 2π).
        if span < 0:
            span += 2 * np.pi
        if span > 2 * np.pi:
            span -= 2 * np.pi
        sectors.append(span)

    return sectors
29
+
30
+
31
def check_kawasaki_at_vertex(vertex_id: int, graph: CreaseGraph) -> tuple[bool, float]:
    """
    Test the Kawasaki-Justin condition at one vertex.

    For a flat-foldable interior vertex with 2n creases, the alternating
    sum of consecutive sector angles vanishes (equivalently, the odd- and
    even-indexed sectors each sum to π).

    Returns (satisfied, |alternating_sum|):
      * (True, 0.0) when the vertex has degree < 4 (not an interior fold
        vertex yet),
      * (False, inf) for odd degree (flat folding is impossible there).
    """
    degree = len(graph.get_cyclic_edges(vertex_id))

    if degree % 2:
        return (False, float('inf'))
    if degree < 4:
        return (True, 0.0)

    # Alternating +/- accumulation over the CCW sector sequence.
    signed = 0.0
    sign = 1.0
    for sector in _compute_sector_angles(vertex_id, graph):
        signed += sign * sector
        sign = -sign

    residual = abs(signed)
    return (residual < 1e-9, residual)
55
+
56
+
57
def check_maekawa_at_vertex(vertex_id: int, graph: CreaseGraph) -> bool:
    """
    Test the Maekawa-Justin condition at one vertex.

    Counts mountain ('M') and valley ('V') creases incident to the vertex;
    boundary ('B') edges are excluded. |M - V| == 2 is only enforced once
    the vertex carries at least 4 fold edges — below that the condition is
    vacuously satisfied (the vertex is not yet active).
    """
    m_count = 0
    v_count = 0
    # Single pass over incident edges, tallying M and V separately.
    for eid in graph.vertex_edges[vertex_id]:
        kind = graph.edges[eid][2]
        if kind == 'M':
            m_count += 1
        elif kind == 'V':
            v_count += 1

    if m_count + v_count < 4:
        return True
    return abs(m_count - v_count) == 2
78
+
79
+
80
def check_blb_at_vertex(vertex_id: int, graph: CreaseGraph) -> list[tuple[int, int]]:
    """
    Test the Big-Little-Big lemma at one vertex.

    Whenever a sector angle is strictly smaller than both of its
    neighbours, the two fold edges bounding that sector must carry
    OPPOSITE M/V assignments.

    Returns the (edge_a_id, edge_b_id) pairs violating this; [] = clean.
    """
    ordered = graph.get_cyclic_edges(vertex_id)
    count = len(ordered)
    if count < 4:
        return []

    sectors = _compute_sector_angles(vertex_id, graph)
    bad_pairs = []

    for i, sector in enumerate(sectors):
        # Strict local minimum relative to both cyclic neighbours.
        if not (sector < sectors[i - 1] and sector < sectors[(i + 1) % count]):
            continue

        left = ordered[i]
        right = ordered[(i + 1) % count]
        kind_left = graph.edges[left][2]
        kind_right = graph.edges[right][2]

        # Only fold edges participate; equal M/V assignments break BLB.
        if kind_left in ('M', 'V') and kind_right in ('M', 'V'):
            if kind_left == kind_right:
                bad_pairs.append((left, right))

    return bad_pairs
115
+
116
+
117
+ def _angle_diff(a1: float, a2: float) -> float:
118
+ """Minimum angle difference between two directed lines (considering 180° symmetry)."""
119
+ diff = abs(a1 - a2) % np.pi
120
+ return min(diff, np.pi - diff)
121
+
122
+
123
def geometric_crease_coverage(
    state: PaperState,
    target_edges: list[dict],
    tol_pos: float = 0.05,
    tol_angle_deg: float = 5.0,
) -> tuple[float, float]:
    """
    Score how closely the current crease pattern matches the target.

    A target segment counts as matched when some current crease has a
    midpoint within tol_pos of its midpoint AND a direction within
    tol_angle_deg of its direction (mod 180°).

    Args:
        target_edges: list of {'v1': (x1,y1), 'v2': (x2,y2), 'assignment': 'M'|'V'}

    Returns:
        (coverage, economy):
          coverage — fraction of target creases matched, in [0, 1]
          economy  — 1 minus a surplus-crease penalty, clamped to [0, 1]
                     (1.0 means no excess creases)
    """
    creases = state.crease_edges()
    angle_tol = np.deg2rad(tol_angle_deg)

    def seg_features(p, q):
        # Midpoint and direction angle of a segment p→q.
        (x1, y1), (x2, y2) = p, q
        return ((x1 + x2) / 2.0, (y1 + y2) / 2.0), np.arctan2(y2 - y1, x2 - x1)

    hits = 0
    for tgt in target_edges:
        t_mid, t_dir = seg_features(tgt['v1'], tgt['v2'])
        for cur in creases:
            c_mid, c_dir = seg_features(cur['v1'], cur['v2'])
            near = np.hypot(c_mid[0] - t_mid[0], c_mid[1] - t_mid[1]) <= tol_pos
            aligned = _angle_diff(c_dir, t_dir) <= angle_tol
            if near and aligned:
                hits += 1
                break

    denom = max(len(target_edges), 1)
    surplus = max(0, len(creases) - len(target_edges))
    return (hits / denom, max(0.0, 1.0 - surplus / denom))
167
+
168
+
169
def check_all_vertices(graph: CreaseGraph) -> dict:
    """
    Run Kawasaki, Maekawa, and BLB checks on every interior vertex.

    Returns a dict with:
        'kawasaki':   fraction of interior vertices passing Kawasaki [0,1]
        'maekawa':    fraction passing Maekawa [0,1]
        'blb':        fraction with no BLB violations [0,1]
        'n_interior': number of interior vertices checked
        'per_vertex': per-vertex detail dicts
    With no interior vertices, every fraction is vacuously 1.0.
    """
    interior = graph.interior_vertices()

    if not interior:
        return {
            'kawasaki': 1.0,
            'maekawa': 1.0,
            'blb': 1.0,
            'n_interior': 0,
            'per_vertex': [],
        }

    details = []
    for vid in interior:
        kaw_ok, kaw_err = check_kawasaki_at_vertex(vid, graph)
        mae_ok = check_maekawa_at_vertex(vid, graph)
        blb_bad = check_blb_at_vertex(vid, graph)
        details.append({
            'vertex_id': vid,
            'kawasaki_ok': kaw_ok,
            'kawasaki_error': kaw_err,
            'maekawa_ok': mae_ok,
            'blb_violations': blb_bad,
        })

    total = len(interior)
    return {
        'kawasaki': sum(d['kawasaki_ok'] for d in details) / total,
        'maekawa': sum(d['maekawa_ok'] for d in details) / total,
        'blb': sum(1 for d in details if not d['blb_violations']) / total,
        'n_interior': total,
        'per_vertex': details,
    }
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ shapely>=2.0.0
2
+ numpy>=1.24.0
3
+ pytest>=7.0.0
tests/__init__.py ADDED
File without changes
tests/test_graph.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pytest
3
+ from env.graph import CreaseGraph, VERTEX_TOL
4
+
5
+
6
+ def test_init_boundary():
7
+ g = CreaseGraph()
8
+ assert len(g.vertices) == 4
9
+ assert len(g.edges) == 4
10
+ for eid, (v1, v2, assignment) in g.edges.items():
11
+ assert assignment == 'B'
12
+ assert g.interior_vertices() == []
13
+
14
+
15
+ def test_add_vertex_dedup():
16
+ g = CreaseGraph()
17
+ id1 = g.add_vertex(0.5, 0.5)
18
+ id2 = g.add_vertex(0.5, 0.5)
19
+ assert id1 == id2
20
+
21
+
22
+ def test_add_vertex_dedup_near():
23
+ g = CreaseGraph()
24
+ id1 = g.add_vertex(0.5, 0.5)
25
+ id2 = g.add_vertex(0.5 + VERTEX_TOL * 0.5, 0.5)
26
+ assert id1 == id2
27
+
28
+
29
+ def test_cyclic_order():
30
+ g = CreaseGraph()
31
+ center_id = g.add_vertex(0.5, 0.5)
32
+
33
+ right_id = g.add_vertex(0.8, 0.5) # 0 degrees
34
+ top_id = g.add_vertex(0.5, 0.8) # 90 degrees
35
+ left_id = g.add_vertex(0.2, 0.5) # 180 degrees
36
+ bottom_id = g.add_vertex(0.5, 0.2) # 270 degrees / -90 degrees
37
+
38
+ e_right = g.add_edge(center_id, right_id, 'M')
39
+ e_top = g.add_edge(center_id, top_id, 'M')
40
+ e_left = g.add_edge(center_id, left_id, 'M')
41
+ e_bottom = g.add_edge(center_id, bottom_id, 'M')
42
+
43
+ cyclic = g.get_cyclic_edges(center_id)
44
+ # Sorted by angle ascending: right(0), top(90), left(180), bottom(-90 → 270)
45
+ # arctan2 for bottom gives -pi/2 which sorts before 0 in ascending order
46
+ # So actual ascending order: bottom(-pi/2), right(0), top(pi/2), left(pi)
47
+ assert len(cyclic) == 4
48
+
49
+ def edge_angle(eid):
50
+ ev1, ev2, _ = g.edges[eid]
51
+ other_id = ev2 if ev1 == center_id else ev1
52
+ ox, oy = g.vertices[other_id]
53
+ cx, cy = g.vertices[center_id]
54
+ return float(np.arctan2(oy - cy, ox - cx))
55
+
56
+ angles = [edge_angle(eid) for eid in cyclic]
57
+ assert angles == sorted(angles), "Edges should be sorted by ascending angle"
58
+
59
+ assert e_right in cyclic
60
+ assert e_top in cyclic
61
+ assert e_left in cyclic
62
+ assert e_bottom in cyclic
63
+
64
+ # Verify specific order: bottom < right < top < left in angle space
65
+ pos = {eid: i for i, eid in enumerate(cyclic)}
66
+ assert pos[e_bottom] < pos[e_right] < pos[e_top] < pos[e_left]
67
+
68
+
69
+ def test_interior_vertices_empty():
70
+ g = CreaseGraph()
71
+ assert g.interior_vertices() == []
72
+
73
+
74
+ def test_interior_vertices_with_crease_intersection():
75
+ g = CreaseGraph()
76
+ center_id = g.add_vertex(0.5, 0.5)
77
+ assert center_id in g.interior_vertices()
78
+
79
+
80
+ def test_split_edge():
81
+ g = CreaseGraph()
82
+ # Find the bottom boundary edge (0,0)-(1,0) which is edge 0: v0-v1
83
+ original_edge_id = None
84
+ for eid, (v1, v2, assignment) in g.edges.items():
85
+ x1, y1 = g.vertices[v1]
86
+ x2, y2 = g.vertices[v2]
87
+ if {(x1, y1), (x2, y2)} == {(0.0, 0.0), (1.0, 0.0)}:
88
+ original_edge_id = eid
89
+ original_v1 = v1
90
+ original_v2 = v2
91
+ break
92
+
93
+ assert original_edge_id is not None
94
+
95
+ mid_id = g.add_vertex(0.5, 0.0)
96
+ eid1, eid2 = g.split_edge(original_edge_id, mid_id)
97
+
98
+ assert original_edge_id not in g.edges
99
+
100
+ assert eid1 in g.edges
101
+ assert eid2 in g.edges
102
+
103
+ _, _, a1 = g.edges[eid1]
104
+ _, _, a2 = g.edges[eid2]
105
+ assert a1 == 'B'
106
+ assert a2 == 'B'
107
+
108
+ def edge_vertex_set(eid):
109
+ v1, v2, _ = g.edges[eid]
110
+ return {v1, v2}
111
+
112
+ assert mid_id in edge_vertex_set(eid1)
113
+ assert mid_id in edge_vertex_set(eid2)
114
+ assert original_v1 in edge_vertex_set(eid1) or original_v1 in edge_vertex_set(eid2)
115
+ assert original_v2 in edge_vertex_set(eid1) or original_v2 in edge_vertex_set(eid2)
tests/test_paper_state.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from env.paper_state import PaperState, UNIT_SQUARE_CORNERS
3
+ from env.graph import VERTEX_TOL
4
+
5
+
6
+ def test_single_crease_no_interior_vertices():
7
+ paper = PaperState()
8
+ result = paper.add_crease([0.0, 0.5], [1.0, 0.5], 'V')
9
+ assert result['valid'] is True
10
+ interior = paper.graph.interior_vertices()
11
+ assert interior == [], f"Expected no interior vertices, got {interior}"
12
+
13
+
14
+ def test_anchor_points_initial():
15
+ paper = PaperState()
16
+ anchors = paper.anchor_points()
17
+ for corner in UNIT_SQUARE_CORNERS:
18
+ assert any(
19
+ abs(ax - corner[0]) < VERTEX_TOL and abs(ay - corner[1]) < VERTEX_TOL
20
+ for ax, ay in anchors
21
+ ), f"Corner {corner} not found in anchor_points"
22
+
23
+
24
+ def test_anchor_points_grow():
25
+ paper = PaperState()
26
+ result = paper.add_crease([0.0, 0.5], [1.0, 0.5], 'V')
27
+ assert result['valid'] is True
28
+
29
+ anchors = paper.anchor_points()
30
+
31
+ def has_point(px, py):
32
+ return any(abs(ax - px) < VERTEX_TOL and abs(ay - py) < VERTEX_TOL for ax, ay in anchors)
33
+
34
+ assert has_point(0.0, 0.5), "(0, 0.5) should be in anchor_points after crease"
35
+ assert has_point(1.0, 0.5), "(1, 0.5) should be in anchor_points after crease"
36
+
37
+
38
+ def test_invalid_assignment():
39
+ paper = PaperState()
40
+ result = paper.add_crease([0.0, 0.5], [1.0, 0.5], 'X')
41
+ assert result['valid'] is False
42
+ assert 'invalid_assignment' in result['errors']
43
+
44
+
45
+ def test_fold_history():
46
+ paper = PaperState()
47
+ paper.add_crease([0.0, 0.5], [1.0, 0.5], 'M')
48
+ assert len(paper.fold_history) == 1
49
+
50
+
51
+ def test_unanchored_returns_false_anchored():
52
+ paper = PaperState()
53
+ result = paper.add_crease([0.3, 0.3], [0.7, 0.7], 'M')
54
+ assert result['anchored'] is False
55
+
56
+
57
+ def test_crease_edges_returned():
58
+ paper = PaperState()
59
+ paper.add_crease([0.0, 0.5], [1.0, 0.5], 'M')
60
+ edges = paper.crease_edges()
61
+ assert len(edges) >= 1
62
+ for e in edges:
63
+ assert e['assignment'] in ('M', 'V')
64
+ assert 'v1' in e
65
+ assert 'v2' in e
66
+
67
+
68
+ def test_two_intersecting_creases():
69
+ paper = PaperState()
70
+ r1 = paper.add_crease([0.0, 0.5], [1.0, 0.5], 'M')
71
+ r2 = paper.add_crease([0.5, 0.0], [0.5, 1.0], 'V')
72
+ assert r1['valid'] is True
73
+ assert r2['valid'] is True
74
+ interior = paper.graph.interior_vertices()
75
+ assert len(interior) >= 1
76
+ coords = [paper.graph.vertices[vid] for vid in interior]
77
+ assert any(abs(x - 0.5) < VERTEX_TOL and abs(y - 0.5) < VERTEX_TOL for x, y in coords)
tests/test_verifier.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import numpy as np
3
+ from env.graph import CreaseGraph
4
+ from env.paper_state import PaperState
5
+ from env.verifier import (
6
+ check_kawasaki_at_vertex,
7
+ check_maekawa_at_vertex,
8
+ check_blb_at_vertex,
9
+ geometric_crease_coverage,
10
+ check_all_vertices,
11
+ )
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Helpers
16
+ # ---------------------------------------------------------------------------
17
+
18
+ def make_cross_graph(center_coords=(0.5, 0.5), assignment='M') -> tuple[CreaseGraph, int]:
19
+ """
20
+ Degree-4 vertex at center with 4 spokes pointing N/S/E/W.
21
+ All spokes have the given assignment.
22
+ """
23
+ g = CreaseGraph()
24
+ cx, cy = center_coords
25
+ vid = g.add_vertex(cx, cy)
26
+
27
+ neighbors = [
28
+ (0.0, cy), # left (180°)
29
+ (1.0, cy), # right (0°)
30
+ (cx, 0.0), # down (-90°)
31
+ (cx, 1.0), # up (90°)
32
+ ]
33
+ for nx, ny in neighbors:
34
+ nid = g.add_vertex(nx, ny)
35
+ g.add_edge(vid, nid, assignment)
36
+
37
+ return g, vid
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Kawasaki tests
42
+ # ---------------------------------------------------------------------------
43
+
44
+ def test_kawasaki_no_interior_vertices():
45
+ paper = PaperState()
46
+ paper.add_crease([0, 0.5], [1, 0.5], 'V')
47
+ assert paper.graph.interior_vertices() == []
48
+ result = check_all_vertices(paper.graph)
49
+ assert result['kawasaki'] == 1.0
50
+ assert result['n_interior'] == 0
51
+
52
+
53
+ def test_kawasaki_valid_degree4_vertex():
54
+ """Equal 90° sectors → alternating sum = 0 → Kawasaki satisfied."""
55
+ g, vid = make_cross_graph()
56
+ ok, err = check_kawasaki_at_vertex(vid, g)
57
+ assert ok == True
58
+ assert err == pytest.approx(0.0, abs=1e-9)
59
+
60
+
61
+ def test_kawasaki_invalid_vertex():
62
+ """
63
+ Manually construct a degree-4 vertex whose sectors are 60°,120°,80°,100°.
64
+ Alternating sum = 60 - 120 + 80 - 100 = -80° ≠ 0 → should fail.
65
+ """
66
+ g = CreaseGraph()
67
+ cx, cy = 0.5, 0.5
68
+ vid = g.add_vertex(cx, cy)
69
+
70
+ # Place neighbours at specific angles so sectors are exactly as desired.
71
+ # Sectors are measured CCW between consecutive rays.
72
+ # We choose ray angles (from center) in ascending arctan2 order:
73
+ # a0 = 0°
74
+ # a1 = 60° (sector0 = 60°)
75
+ # a2 = 180° (sector1 = 120°)
76
+ # a3 = 260° = -100° (sector2 = 80°)
77
+ # sector3 (wraparound to a0) = 360° - 260° = 100°
78
+ # alt_sum = 60 - 120 + 80 - 100 = -80° → |alt_sum| ≈ 1.396 rad
79
+ r = 0.3
80
+ angles_deg = [0.0, 60.0, 180.0, 260.0]
81
+ for deg in angles_deg:
82
+ rad = np.deg2rad(deg)
83
+ nx = cx + r * np.cos(rad)
84
+ ny = cy + r * np.sin(rad)
85
+ nid = g.add_vertex(nx, ny)
86
+ g.add_edge(vid, nid, 'M')
87
+
88
+ ok, err = check_kawasaki_at_vertex(vid, g)
89
+ assert ok == False
90
+ expected_err = abs(np.deg2rad(60 - 120 + 80 - 100))
91
+ assert err == pytest.approx(expected_err, abs=1e-6)
92
+
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # Maekawa tests
96
+ # ---------------------------------------------------------------------------
97
+
98
+ def test_maekawa_excludes_boundary():
99
+ """
100
+ Boundary edges at a vertex should NOT count toward M/V tally.
101
+ A corner vertex has only boundary edges; Maekawa should return True
102
+ (fewer than 4 fold edges → vacuously satisfied).
103
+ """
104
+ g = CreaseGraph()
105
+ corner_id = 0 # vertex (0,0)
106
+ assert check_maekawa_at_vertex(corner_id, g) is True
107
+
108
+
109
+ def test_maekawa_valid():
110
+ """3 M + 1 V → |3-1| = 2 → True."""
111
+ g = CreaseGraph()
112
+ cx, cy = 0.5, 0.5
113
+ vid = g.add_vertex(cx, cy)
114
+
115
+ r = 0.3
116
+ angles_deg = [0.0, 90.0, 180.0, 270.0]
117
+ assignments = ['M', 'M', 'M', 'V']
118
+ for deg, asgn in zip(angles_deg, assignments):
119
+ rad = np.deg2rad(deg)
120
+ nid = g.add_vertex(cx + r * np.cos(rad), cy + r * np.sin(rad))
121
+ g.add_edge(vid, nid, asgn)
122
+
123
+ assert check_maekawa_at_vertex(vid, g) is True
124
+
125
+
126
+ def test_maekawa_invalid():
127
+ """2 M + 2 V → |2-2| = 0 → False."""
128
+ g = CreaseGraph()
129
+ cx, cy = 0.5, 0.5
130
+ vid = g.add_vertex(cx, cy)
131
+
132
+ r = 0.3
133
+ angles_deg = [0.0, 90.0, 180.0, 270.0]
134
+ assignments = ['M', 'M', 'V', 'V']
135
+ for deg, asgn in zip(angles_deg, assignments):
136
+ rad = np.deg2rad(deg)
137
+ nid = g.add_vertex(cx + r * np.cos(rad), cy + r * np.sin(rad))
138
+ g.add_edge(vid, nid, asgn)
139
+
140
+ assert check_maekawa_at_vertex(vid, g) is False
141
+
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # BLB tests
145
+ # ---------------------------------------------------------------------------
146
+
147
+ def test_blb_no_violations_equal_sectors():
148
+ """Equal 90° sectors → no strict local minimum → BLB returns []."""
149
+ g, vid = make_cross_graph()
150
+ violations = check_blb_at_vertex(vid, g)
151
+ assert violations == []
152
+
153
+
154
+ def test_blb_violation_detected():
155
+ """
156
+ Create a vertex with a strict local-minimum sector whose bounding edges
157
+ share the same MV assignment → BLB violation.
158
+
159
+ Use angles 0°, 10°, 180°, 270° so sector[0]=10° is the strict local min
160
+ relative to sector[3] (90°) and sector[1] (170°). The two bounding edges
161
+ are at 0° and 10°; assign both 'M' → violation.
162
+ """
163
+ g = CreaseGraph()
164
+ cx, cy = 0.5, 0.5
165
+ vid = g.add_vertex(cx, cy)
166
+
167
+ r = 0.3
168
+ # angles ascending (arctan2 order): 0°, 10°, 180°, 270° (= -90°)
169
+ # sorted arctan2: -90°, 0°, 10°, 180°
170
+ # sectors: 90°, 10°, 170°, 90° (sum=360°)
171
+ # sector at index 1 (between 0° and 10°) = 10° is strict local min (90 > 10 < 170)
172
+ angles_deg = [0.0, 10.0, 180.0, 270.0]
173
+ edge_ids = []
174
+ for deg in angles_deg:
175
+ rad = np.deg2rad(deg)
176
+ nid = g.add_vertex(cx + r * np.cos(rad), cy + r * np.sin(rad))
177
+ eid = g.add_edge(vid, nid, 'M')
178
+ edge_ids.append(eid)
179
+
180
+ violations = check_blb_at_vertex(vid, g)
181
+ assert len(violations) > 0
182
+
183
+
184
+ def test_blb_no_violation_when_opposite_assignments():
185
+ """
186
+ Same geometry as above but with opposite assignments on the two edges
187
+ bounding the small sector → no BLB violation.
188
+ """
189
+ g = CreaseGraph()
190
+ cx, cy = 0.5, 0.5
191
+ vid = g.add_vertex(cx, cy)
192
+
193
+ r = 0.3
194
+ angles_deg = [0.0, 10.0, 180.0, 270.0]
195
+ # sorted arctan2: -90°(270°), 0°, 10°, 180°
196
+ # small sector is between 0° and 10° (index 1 and 2 in sorted order)
197
+ # assign them opposite assignments
198
+ assignments_by_angle = {
199
+ 0.0: 'M',
200
+ 10.0: 'V',
201
+ 180.0: 'M',
202
+ 270.0: 'V',
203
+ }
204
+ for deg in angles_deg:
205
+ rad = np.deg2rad(deg)
206
+ nid = g.add_vertex(cx + r * np.cos(rad), cy + r * np.sin(rad))
207
+ g.add_edge(vid, nid, assignments_by_angle[deg])
208
+
209
+ violations = check_blb_at_vertex(vid, g)
210
+ assert violations == []
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Coverage tests
215
+ # ---------------------------------------------------------------------------
216
+
217
+ def test_coverage_exact_match():
218
+ """Add exact crease matching target → coverage = 1.0, economy = 1.0."""
219
+ paper = PaperState()
220
+ paper.add_crease([0.0, 0.5], [1.0, 0.5], 'M')
221
+
222
+ target = [{'v1': (0.0, 0.5), 'v2': (1.0, 0.5), 'assignment': 'M'}]
223
+ coverage, economy = geometric_crease_coverage(paper, target)
224
+ assert coverage == pytest.approx(1.0)
225
+ assert economy == pytest.approx(1.0)
226
+
227
+
228
+ def test_coverage_no_match():
229
+ """No creases added → coverage = 0.0."""
230
+ paper = PaperState()
231
+ target = [{'v1': (0.0, 0.5), 'v2': (1.0, 0.5), 'assignment': 'M'}]
232
+ coverage, economy = geometric_crease_coverage(paper, target)
233
+ assert coverage == pytest.approx(0.0)
234
+
235
+
236
+ def test_coverage_excess_penalty():
237
+ """
238
+ Target has 1 crease. Add 3 non-intersecting creases, one matching target.
239
+ coverage = 1.0, economy = 1 - 2/1 → clamped to 0.0 (economy < 1.0).
240
+ Uses non-intersecting extras to avoid PaperState edge splitting the target crease.
241
+ """
242
+ paper = PaperState()
243
+ paper.add_crease([0.0, 0.5], [1.0, 0.5], 'M') # matches target (midpoint 0.5,0.5)
244
+ paper.add_crease([0.0, 0.3], [0.5, 0.3], 'V') # extra, no intersection
245
+ paper.add_crease([0.0, 0.7], [0.5, 0.7], 'V') # extra, no intersection
246
+
247
+ target = [{'v1': (0.0, 0.5), 'v2': (1.0, 0.5), 'assignment': 'M'}]
248
+ coverage, economy = geometric_crease_coverage(paper, target)
249
+ assert coverage == pytest.approx(1.0)
250
+ assert economy < 1.0
251
+
252
+
253
+ # ---------------------------------------------------------------------------
254
+ # check_all_vertices vacuous test
255
+ # ---------------------------------------------------------------------------
256
+
257
+ def test_check_all_vertices_vacuous():
258
+ """Single horizontal crease → no interior vertices → all scores = 1.0."""
259
+ paper = PaperState()
260
+ paper.add_crease([0.0, 0.5], [1.0, 0.5], 'V')
261
+ result = check_all_vertices(paper.graph)
262
+ assert result['kawasaki'] == 1.0
263
+ assert result['maekawa'] == 1.0
264
+ assert result['blb'] == 1.0
265
+ assert result['n_interior'] == 0
266
+ assert result['per_vertex'] == []
train.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OrigamiRL — GRPO Training Script
3
+ Code-as-policy: model generates complete fold sequence, gets terminal reward.
4
+
5
+ Usage:
6
+ python train.py
7
+ python train.py --model unsloth/Qwen2.5-7B-Instruct --epochs 3 --output origami-grpo
8
+ """
9
+ import argparse
10
+ import json
11
+ import copy
12
+ import random
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+
17
def parse_args(argv=None):
    """Parse command-line options for GRPO training.

    Args:
        argv: Optional explicit argument list. Defaults to None, which makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            allows programmatic/test use without patching sys.argv.

    Returns:
        argparse.Namespace with the model, optimization, and environment
        hyperparameters documented in each ``add_argument`` call below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='unsloth/Qwen2.5-7B-Instruct')
    parser.add_argument('--max_seq_length', type=int, default=2048)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--grad_accum', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--n_generations', type=int, default=8)
    parser.add_argument('--max_folds', type=int, default=8)
    parser.add_argument('--output', default='origami-grpo')
    parser.add_argument('--level', type=int, default=1, help='Target difficulty level (1-3)')
    parser.add_argument('--dry_run', action='store_true', help='Test reward function without training')
    return parser.parse_args(argv)
31
+
32
+
33
def build_dataset(env, level: int = 1, max_folds: int = 8) -> list[dict]:
    """
    Build a training dataset of prompts from available targets.

    Each item: {'prompt': str, 'target_name': str}.
    Repeats each target multiple times to give enough training steps.

    Args:
        env: Environment exposing ``available_targets()``, ``reset()``, and a
            ``_targets`` mapping with per-target metadata.
        level: Difficulty level to select; falls back to all targets when no
            target matches.
        max_folds: Unused here; kept for call-site compatibility.

    Returns:
        Shuffled list of at least 50 {'prompt', 'target_name'} dicts.

    Raises:
        ValueError: If the environment exposes no targets at all.
    """
    all_names = env.available_targets()
    if not all_names:
        # Without this guard the repeat computation below divides by zero.
        raise ValueError('environment exposes no targets to build a dataset from')

    # Filter by level; fall back to all targets if none match
    level_names = [
        name for name in all_names
        if env._targets[name].get('level', 1) == level
    ] or all_names

    items = []
    for name in level_names:
        obs = env.reset(target_name=name)
        items.append({'prompt': obs['prompt'], 'target_name': name})

    # Repeat each target 10x; ensure at least 50 examples
    repeat = max(10, (50 + len(items) - 1) // len(items))
    items = items * repeat

    random.shuffle(items)
    return items
61
+
62
+
63
def make_reward_fn(env_template, max_folds: int):
    """
    Build a reward function compatible with trl's GRPOTrainer.

    Returned callable signature:
        reward_fn(completions, prompts=None, **kwargs) -> list[float]

    Each completion is scored by cloning the template environment (fresh
    paper state), resetting to its target (``target_names`` from kwargs when
    available), and executing the completion as a fold sequence. Any failure
    along the way contributes a flat -0.1 penalty instead of raising.
    """
    def reward_fn(completions, prompts=None, **kwargs):
        names = kwargs.get('target_names', [None] * len(completions))
        scores = []

        for text, name in zip(completions, names):
            try:
                episode = env_template.clone()
                episode.reset(target_name=name)
                _, breakdown, _, _ = episode.step(text)
                score = float(breakdown['total'])
            except Exception:
                score = -0.1
            scores.append(score)

        return scores

    return reward_fn
91
+
92
+
93
def make_detailed_reward_fns(env_template, max_folds: int) -> list:
    """
    Build one reward function per reward component for detailed W&B logging.

    Components: kawasaki, maekawa, blb, progress, economy, completion.
    Each returned callable (named ``reward_<component>``) scores completions
    on just its component; any execution failure contributes 0.0.
    """
    def _fn_for(component: str):
        def component_fn(completions, prompts=None, **kwargs):
            names = kwargs.get('target_names', [None] * len(completions))
            values = []

            for text, name in zip(completions, names):
                try:
                    episode = env_template.clone()
                    episode.reset(target_name=name)
                    _, breakdown, _, _ = episode.step(text)
                    values.append(float(breakdown.get(component, 0.0)))
                except Exception:
                    values.append(0.0)

            return values

        component_fn.__name__ = f'reward_{component}'
        return component_fn

    return [
        _fn_for(c)
        for c in ('kawasaki', 'maekawa', 'blb', 'progress', 'economy', 'completion')
    ]
121
+
122
+
123
def main():
    """Entry point: build the dataset, optionally dry-run the reward function,
    otherwise load the model via unsloth and train with TRL's GRPOTrainer."""
    args = parse_args()

    # Import here to allow dry_run without GPU
    from env.environment import OrigamiEnvironment

    env = OrigamiEnvironment(mode='code_as_policy', max_steps=args.max_folds)

    # Build dataset of prompts, one row per (repeated) target.
    dataset_items = build_dataset(env, level=args.level, max_folds=args.max_folds)
    print(f"Dataset: {len(dataset_items)} examples from level-{args.level} targets")
    print(f"Targets: {env.available_targets()}")

    # Dry run: test reward function without loading model
    if args.dry_run:
        reward_fn = make_reward_fn(env, args.max_folds)
        test_completions = [
            '<folds>[{"instruction": "Valley fold along horizontal center", "from": [0, 0.5], "to": [1, 0.5], "assignment": "V"}]</folds>',
            '<folds>[{"instruction": "Invalid fold", "from": [0.3, 0.3], "to": [0.7, 0.7], "assignment": "V"}]</folds>',
            'this is not valid JSON at all',
        ]
        target_names = [dataset_items[0]['target_name']] * 3
        rewards = reward_fn(test_completions, target_names=target_names)
        print(f"\nDry run rewards: {rewards}")
        print("Dry run complete — reward function works.")
        return

    # Load model via unsloth (optional heavy dependency; only needed to train).
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        print("ERROR: unsloth not installed. Run: pip install unsloth")
        print("Or run with --dry_run to test the reward function without a model.")
        return

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )

    # Attach LoRA adapters so only a small parameter subset is trained.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,
        lora_dropout=0,
        use_gradient_checkpointing="unsloth",
    )

    # Convert dataset to HuggingFace Dataset format
    from datasets import Dataset

    # GRPOTrainer expects 'prompt' column and optionally others.
    # We embed target_name in the dataset so the reward fn can use it.
    hf_dataset = Dataset.from_list(dataset_items)

    # Build main reward function
    reward_fn = make_reward_fn(env, args.max_folds)

    from trl import GRPOConfig, GRPOTrainer

    config = GRPOConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_completion_length=512,
        num_generations=args.n_generations,
        temperature=1.0,
        logging_steps=1,
        report_to="wandb",
        run_name="origami-grpo",
    )

    # GRPOTrainer passes all dataset columns as kwargs to reward_funcs.
    # The 'target_name' column arrives as a list (one per completion in the batch).
    def wrapped_reward_fn(completions, target_name=None, **kwargs):
        """Wrapper that extracts target_name from batch columns."""
        target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)
        return reward_fn(completions, target_names=target_names)

    # Fix: GRPOTrainer follows the HF Trainer API — the config is passed as
    # `args=` (not `config=`) and the tokenizer as `processing_class=`
    # (not `tokenizer=`); the original keywords raise TypeError.
    trainer = GRPOTrainer(
        model=model,
        args=config,
        train_dataset=hf_dataset,
        reward_funcs=[wrapped_reward_fn],
        processing_class=tokenizer,
    )

    print(f"\nStarting GRPO training...")
    print(f"  Model: {args.model}")
    print(f"  Level: {args.level} targets")
    print(f"  Epochs: {args.epochs}")
    print(f"  Generations per prompt: {args.n_generations}")
    print(f"  Output: {args.output}/")

    trainer.train()

    # Save LoRA adapter weights and tokenizer for later inference.
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    print(f"\nModel saved to {args.output}/")
228
+
229
+
230
# Script entry point: run training (or --dry_run reward check) when executed directly.
if __name__ == '__main__':
    main()