ianalin123 commited on
Commit
8484129
·
1 Parent(s): d82e085

feat(v2): add step_reward.py — per-step Kawasaki/Maekawa/coverage reward

Browse files
Files changed (1) hide show
  1. origami_server/engine/step_reward.py +392 -0
origami_server/engine/step_reward.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-step reward computation for V2 multi-step origami episodes.
2
+
3
+ Combines verifier (Kawasaki/Maekawa/BLB) and coverage-based reward from optigami.
4
+ """
5
+
6
+ import numpy as np
7
+ from .graph import CreaseGraph
8
+ from .paper_state import PaperState
9
+
10
+
11
def _compute_sector_angles(vertex_id: int, graph: CreaseGraph) -> list[float]:
    """Return the consecutive sector angles (radians) around a vertex.

    Edges are visited in the order given by ``graph.get_cyclic_edges``
    (presumably sorted counter-clockwise -- confirm against CreaseGraph).
    Sector ``i`` is the angle swept from the direction of cyclic edge ``i``
    to the direction of cyclic edge ``i + 1`` (wrapping around), normalised
    into ``[0, 2π)``.

    Args:
        vertex_id: id of the vertex whose sectors are measured.
        graph: crease graph exposing ``vertices``, ``edges`` and
            ``get_cyclic_edges``.

    Returns:
        One sector angle per cyclic edge. A vertex with a single incident
        edge has exactly one sector spanning the full turn (2π); a vertex
        with no edges yields an empty list.
    """
    cyclic_edges = graph.get_cyclic_edges(vertex_id)
    n = len(cyclic_edges)
    vx, vy = graph.vertices[vertex_id]

    # Direction (atan2 angle) of every incident edge, pointing away from the vertex.
    directions = []
    for eid in cyclic_edges:
        ev1, ev2, _ = graph.edges[eid]
        other_id = ev2 if ev1 == vertex_id else ev1
        ox, oy = graph.vertices[other_id]
        directions.append(np.arctan2(oy - vy, ox - vx))

    if n == 1:
        # A lone edge bounds a single sector that wraps all the way around;
        # the generic difference below would wrongly report 0.
        return [2 * np.pi]

    sectors = []
    for i, start in enumerate(directions):
        sweep = directions[(i + 1) % n] - start
        # atan2 values lie in [-π, π], so the raw difference is in [-2π, 2π];
        # one wrap is enough to land in [0, 2π).
        if sweep < 0:
            sweep += 2 * np.pi
        sectors.append(sweep)

    return sectors
34
+
35
+
36
def check_kawasaki_at_vertex(
    vertex_id: int,
    graph: CreaseGraph,
    tol: float = 1e-9,
) -> tuple[bool, float]:
    """Check the Kawasaki-Justin theorem at a single vertex.

    Kawasaki: at an interior vertex with 2n creases, the alternating sum
    of consecutive sector angles is 0. Equivalently, the odd-indexed and
    even-indexed sectors each sum to π.

    Args:
        vertex_id: id of the vertex to check.
        graph: crease graph exposing ``get_cyclic_edges`` (plus whatever
            ``_compute_sector_angles`` reads).
        tol: absolute tolerance (radians) on the alternating sum.

    Returns:
        ``(satisfied, |alternating_sum|)``.
        * Any odd degree (including 1 and 3) returns ``(False, inf)``; the
          parity test runs first because an odd-degree vertex can never
          fold flat.
        * Even degree below 4 (0 or 2) returns ``(True, 0.0)``: the vertex
          is not an active fold vertex yet.
    """
    degree = len(graph.get_cyclic_edges(vertex_id))

    if degree % 2 != 0:
        return (False, float('inf'))

    if degree < 4:
        return (True, 0.0)

    sectors = _compute_sector_angles(vertex_id, graph)
    # Same left-to-right accumulation as +s0 - s1 + s2 - ...
    alt_sum = sum(s if i % 2 == 0 else -s for i, s in enumerate(sectors))
    error = abs(alt_sum)
    # bool() keeps the flag a plain Python bool even when the sectors are
    # numpy scalars, matching the annotated return type.
    return (bool(error < tol), error)
60
+
61
+
62
def check_maekawa_at_vertex(vertex_id: int, graph: CreaseGraph) -> bool:
    """Check the Maekawa-Justin theorem at a single vertex.

    Maekawa: ``|M - V| == 2`` where M and V are the numbers of mountain and
    valley fold edges at the vertex. Only edges whose assignment (third
    element of the edge tuple) is ``'M'`` or ``'V'`` are counted; boundary
    (``'B'``) and any other assignment are ignored.

    Args:
        vertex_id: id of the vertex to check.
        graph: crease graph exposing ``vertex_edges`` and ``edges``.

    Returns:
        True if the condition holds, or if the vertex has fewer than 4 fold
        edges (treated as not yet active).
    """
    assignments = [graph.edges[eid][2] for eid in graph.vertex_edges[vertex_id]]
    mountains = assignments.count('M')
    valleys = assignments.count('V')

    if mountains + valleys < 4:
        return True

    return abs(mountains - valleys) == 2
83
+
84
+
85
def check_blb_at_vertex(vertex_id: int, graph: CreaseGraph) -> list[tuple[int, int]]:
    """Check the Big-Little-Big lemma at a single vertex.

    BLB: if sector ``i`` is a strict local minimum (smaller than both of its
    neighbouring sectors), the two fold edges bounding it must carry
    opposite mountain/valley assignments.

    Sector ``i`` is bounded by cyclic edges ``i`` and ``i + 1`` (wrapping),
    which is the convention used by ``_compute_sector_angles``.

    Args:
        vertex_id: id of the vertex to check.
        graph: crease graph exposing ``get_cyclic_edges`` and ``edges``.

    Returns:
        List of ``(edge_a_id, edge_b_id)`` pairs that violate BLB. An empty
        list means no violations; vertices with fewer than 4 cyclic edges
        always return an empty list.
    """
    cyclic_edges = graph.get_cyclic_edges(vertex_id)
    n = len(cyclic_edges)

    if n < 4:
        return []

    sectors = _compute_sector_angles(vertex_id, graph)
    violations = []

    for i in range(n):
        is_strict_minimum = (
            sectors[i] < sectors[(i - 1) % n]
            and sectors[i] < sectors[(i + 1) % n]
        )
        if not is_strict_minimum:
            continue

        edge_a = cyclic_edges[i]
        edge_b = cyclic_edges[(i + 1) % n]
        assign_a = graph.edges[edge_a][2]
        assign_b = graph.edges[edge_b][2]

        # Only pairs where both edges are real folds (M/V) can violate BLB;
        # equality with an M/V value implies the other edge is M/V as well.
        if assign_a in ('M', 'V') and assign_a == assign_b:
            violations.append((edge_a, edge_b))

    return violations
120
+
121
+
122
def _angle_diff(a1: float, a2: float) -> float:
    """Return the smallest angle (radians) between two undirected lines.

    A line at angle ``θ`` is the same line as one at ``θ + π``, so the
    difference is reduced modulo π and folded into ``[0, π/2]``.
    """
    diff = abs(a1 - a2) % np.pi
    return min(diff, np.pi - diff)
126
+
127
+
128
def geometric_crease_coverage(
    state: PaperState,
    target_edges: list[dict],
    tol_pos: float = 0.05,
    tol_angle_deg: float = 5.0,
) -> tuple[float, float, float]:
    """Compute how well the current crease pattern matches the target.

    A current crease matches a target crease positionally when their
    midpoints are within ``tol_pos`` of each other and their directions
    (treated as undirected lines) differ by at most ``tol_angle_deg``.
    Each target is scored against the FIRST positionally matching current
    crease in ``state.crease_edges()`` order; a current crease may be
    matched by more than one target.

    Args:
        state: current paper state; only ``state.crease_edges()`` is used.
        target_edges: list of
            ``{'v1': (x1, y1), 'v2': (x2, y2), 'assignment': 'M'|'V'}``.
            A missing ``'assignment'`` is treated as ``'M'``.
        tol_pos: position tolerance for midpoint matching.
        tol_angle_deg: angle tolerance in degrees for direction matching.

    Returns:
        ``(coverage, economy, assignment_accuracy)``
        coverage: weighted fraction of target creases matched, in [0, 1];
            a target contributes 1.0 if position and assignment match and
            0.5 if only the position matches.
        economy: penalty for excess creases, in [0, 1]; 1.0 means the
            current pattern has no more creases than the target.
        assignment_accuracy: fraction of positionally matched targets that
            also have the correct M/V assignment, in [0, 1]; 1.0 when there
            are no positional matches (vacuous case).
    """
    current_edges = state.crease_edges()
    tol_angle_rad = np.deg2rad(tol_angle_deg)

    # Midpoint, direction and assignment of each current crease do not
    # depend on the target, so compute them once instead of per target.
    current_features = []
    for current in current_edges:
        cx1, cy1 = current['v1']
        cx2, cy2 = current['v2']
        current_features.append((
            (cx1 + cx2) / 2.0,
            (cy1 + cy2) / 2.0,
            np.arctan2(cy2 - cy1, cx2 - cx1),
            current.get('assignment', 'M'),
        ))

    total_score = 0.0
    position_matches = 0
    assignment_correct = 0

    for target in target_edges:
        tx1, ty1 = target['v1']
        tx2, ty2 = target['v2']
        t_mid_x = (tx1 + tx2) / 2.0
        t_mid_y = (ty1 + ty2) / 2.0
        t_angle = np.arctan2(ty2 - ty1, tx2 - tx1)
        t_assign = target.get('assignment', 'M')

        for c_mid_x, c_mid_y, c_angle, c_assign in current_features:
            mid_dist = np.hypot(c_mid_x - t_mid_x, c_mid_y - t_mid_y)
            if mid_dist <= tol_pos and _angle_diff(c_angle, t_angle) <= tol_angle_rad:
                position_matches += 1
                if t_assign == c_assign:
                    total_score += 1.0
                    assignment_correct += 1
                else:
                    total_score += 0.5
                # First positional match wins; stop scanning for this target.
                break

    n_targets = max(len(target_edges), 1)  # avoid dividing by zero on an empty target
    coverage = total_score / n_targets
    n_excess = max(0, len(current_edges) - len(target_edges))
    economy = max(0.0, 1.0 - n_excess / n_targets)
    assignment_accuracy = (
        assignment_correct / position_matches if position_matches > 0 else 1.0
    )
    return (coverage, economy, assignment_accuracy)
192
+
193
+
194
def check_degree_sanity(graph: CreaseGraph) -> float:
    """Return the fraction of interior vertices that have even degree.

    Even degree at every interior vertex is required for flat-foldability.
    Degree is the number of entries in ``graph.vertex_edges[vid]``.

    Args:
        graph: crease graph exposing ``interior_vertices()`` and
            ``vertex_edges``.

    Returns:
        Value in [0, 1]: 1.0 when every interior vertex has even degree,
        0.0 when none does, and 1.0 when there are no interior vertices
        (vacuous case).
    """
    interior = graph.interior_vertices()
    if not interior:
        return 1.0

    even_count = sum(
        1 for vid in interior
        if len(graph.vertex_edges[vid]) % 2 == 0
    )
    return even_count / len(interior)
212
+
213
+
214
def check_all_vertices(graph: CreaseGraph) -> dict:
    """Run all vertex-level checks on every interior vertex.

    Args:
        graph: crease graph exposing ``interior_vertices()`` plus whatever
            the individual Kawasaki / Maekawa / BLB checks read.

    Returns:
        dict with:
            'kawasaki': float         # fraction of interior vertices passing Kawasaki [0, 1]
            'maekawa': float          # fraction passing Maekawa [0, 1]
            'blb': float              # fraction with no BLB violations [0, 1]
            'n_interior': int         # number of interior vertices checked
            'per_vertex': list[dict]  # per-vertex details
        With no interior vertices every fraction is 1.0 (vacuous case) and
        'per_vertex' is empty.
    """
    interior = graph.interior_vertices()

    if not interior:
        return {
            'kawasaki': 1.0,
            'maekawa': 1.0,
            'blb': 1.0,
            'n_interior': 0,
            'per_vertex': [],
        }

    per_vertex = []
    kaw_pass = 0
    mae_pass = 0
    blb_pass = 0

    for vid in interior:
        kaw_ok, kaw_val = check_kawasaki_at_vertex(vid, graph)
        mae_ok = check_maekawa_at_vertex(vid, graph)
        blb_violations = check_blb_at_vertex(vid, graph)
        blb_ok = len(blb_violations) == 0

        kaw_pass += int(kaw_ok)
        mae_pass += int(mae_ok)
        blb_pass += int(blb_ok)

        per_vertex.append({
            'vertex_id': vid,
            'kawasaki_ok': kaw_ok,
            'kawasaki_error': kaw_val,  # may be inf for odd-degree vertices
            'maekawa_ok': mae_ok,
            'blb_violations': blb_violations,
        })

    n = len(interior)
    return {
        'kawasaki': kaw_pass / n,
        'maekawa': mae_pass / n,
        'blb': blb_pass / n,
        'n_interior': n,
        'per_vertex': per_vertex,
    }
267
+
268
+
269
def target_crease_edges(target: dict) -> list[dict]:
    """Extract the mountain/valley crease edges from a FOLD target dict.

    Args:
        target: FOLD-style dict with ``'vertices_coords'``,
            ``'edges_vertices'`` and ``'edges_assignment'``.

    Returns:
        List of ``{'v1': (x1, y1), 'v2': (x2, y2), 'assignment': 'M'|'V'}``
        dicts, one per edge assigned ``'M'`` or ``'V'``, in edge order.
        Edges with any other assignment (e.g. boundary) are skipped.

    Raises:
        KeyError: if one of the three FOLD keys is missing.
        IndexError: if ``'edges_assignment'`` is shorter than
            ``'edges_vertices'`` or an edge references a missing vertex.
    """
    verts = target['vertices_coords']
    assignments = target['edges_assignment']
    result = []
    for i, (v1_idx, v2_idx) in enumerate(target['edges_vertices']):
        assignment = assignments[i]
        if assignment not in ('M', 'V'):
            continue
        result.append({
            # FOLD allows [x, y, z] coordinates; consumers of these dicts
            # unpack exactly (x, y), so keep only the planar components.
            'v1': tuple(verts[v1_idx][:2]),
            'v2': tuple(verts[v2_idx][:2]),
            'assignment': assignment,
        })
    return result
285
+
286
+
287
def compute_reward(
    prev_state: PaperState,
    action_result: dict,
    new_state: PaperState,
    target: dict,
    step: int,
    max_steps: int,
) -> dict:
    """Compute the full reward dict for a fold action.

    Only the format check acts as a hard gate: an invalid action returns
    immediately. All other components are combined in a weighted sum.

    Args:
        prev_state: PaperState BEFORE the action was applied.
        action_result: ``{'valid': bool, 'anchored': bool, 'duplicate': bool, ...}``;
            missing keys default to False.
        new_state: PaperState AFTER the action was applied.
        target: FOLD target dict.
        step: current step index.
        max_steps: maximum steps in the episode. A non-positive value is
            treated as "budget fully spent" for the efficiency term.

    Returns:
        For an invalid action: ``{'format': 0.0, 'total': -0.1}`` only.
        Otherwise a dict with keys:
            format, anchored, novelty, kawasaki, maekawa, blb, degree_sanity,
            progress, economy, assignment_accuracy, delta, regression,
            completion, efficiency, total
    """
    r = {}

    # GATE 1: Format -- did the action parse and apply?
    r['format'] = 1.0 if action_result.get('valid', False) else 0.0
    if not r['format']:
        r['total'] = -0.1
        return r

    # GATE 2: Structural sanity (soft: lower scores, no early exit)
    r['anchored'] = 1.0 if action_result.get('anchored', False) else 0.3
    r['novelty'] = 0.0 if action_result.get('duplicate', False) is True else 0.2

    # LEVEL 3: Local flat-foldability
    vertex_scores = check_all_vertices(new_state.graph)
    r['kawasaki'] = vertex_scores['kawasaki']
    r['maekawa'] = vertex_scores['maekawa']
    r['blb'] = vertex_scores['blb']
    r['degree_sanity'] = check_degree_sanity(new_state.graph)

    # LEVEL 4: Progress (absolute + delta)
    t_edges = target_crease_edges(target)
    old_coverage, _, _ = geometric_crease_coverage(prev_state, t_edges)
    new_coverage, economy, assignment_accuracy = geometric_crease_coverage(new_state, t_edges)
    coverage_change = new_coverage - old_coverage

    r['progress'] = new_coverage
    r['economy'] = economy
    r['assignment_accuracy'] = assignment_accuracy
    r['delta'] = max(0.0, coverage_change)       # improvement only, >= 0
    r['regression'] = min(0.0, coverage_change)  # loss only, <= 0

    # LEVEL 5: Completion bonus -- near-full coverage with every vertex check passing
    all_valid = (
        r['kawasaki'] == 1.0
        and r['maekawa'] == 1.0
        and r['blb'] == 1.0
    )
    r['completion'] = 10.0 if (r['progress'] > 0.9 and all_valid) else 0.0

    # LEVEL 6: Efficiency -- escalating step cost.
    # Guard the division: max_steps == 0 would otherwise raise ZeroDivisionError.
    elapsed_fraction = step / max_steps if max_steps > 0 else 1.0
    r['efficiency'] = -0.01 * (1 + elapsed_fraction)

    # Weighted total
    r['total'] = (
        0.05 * r['anchored']
        + 0.05 * r['novelty']
        + 0.06 * r['kawasaki']
        + 0.06 * r['maekawa']
        + 0.04 * r['blb']
        + 0.04 * r['degree_sanity']
        + 0.25 * r['progress']
        + 0.05 * r['economy']
        + 0.05 * r['assignment_accuracy']
        + 0.20 * r['delta']
        + 0.10 * r['regression']
        + r['completion']
        + r['efficiency']
    )
    return r
369
+
370
+
371
def compute_terminal_reward(
    state: PaperState,
    target: dict,
    max_steps: int,
) -> dict:
    """Compute the reward for the final state after a complete fold sequence.

    Delegates to ``compute_reward`` with a default-constructed
    ``PaperState()`` as the "before" state, a synthetic action result marked
    valid / anchored / non-duplicate, and ``step`` set to ``max_steps``.

    Args:
        state: final PaperState to score.
        target: FOLD target dict.
        max_steps: maximum steps in the episode.

    Returns:
        The reward dict produced by ``compute_reward``.
    """
    # Synthetic result so the format gate passes and no structural penalty applies.
    fake_result = {
        'valid': True,
        'anchored': True,
        'duplicate': False,
    }
    return compute_reward(
        prev_state=PaperState(),
        action_result=fake_result,
        new_state=state,
        target=target,
        step=max_steps,
        max_steps=max_steps,
    )