Spaces:
Running
Running
| import json | |
| import numpy as np | |
| from .verifier import check_all_vertices, check_degree_sanity, geometric_crease_coverage | |
| from .paper_state import PaperState | |
| from .shape_reward import compute_3d_shape_reward | |
def load_target(target_path: str) -> dict:
    """Load a .fold target file and return its parsed JSON contents.

    Args:
        target_path: Path to a FOLD (JSON) file.

    Returns:
        The decoded JSON document (a dict for well-formed FOLD files).
    """
    # FOLD files are JSON, which must be UTF-8 encoded; decode explicitly so
    # the result does not depend on the platform's default locale encoding.
    with open(target_path, encoding='utf-8') as f:
        return json.load(f)
def target_crease_edges(target: dict) -> list[dict]:
    """
    Collect the mountain/valley creases of a FOLD target.

    Each returned item is a dict of the form
    {'v1': (x1, y1), 'v2': (x2, y2), 'assignment': 'M' | 'V'}, where the
    coordinates are the target's vertex coordinates converted to tuples.
    Edges with any other assignment (e.g. boundary or flat) are skipped.
    """
    coords = target['vertices_coords']
    creases = []
    for edge_idx, (start, end) in enumerate(target['edges_vertices']):
        # Looked up per edge, so a target with no edges needs no assignments.
        label = target['edges_assignment'][edge_idx]
        if label not in ('M', 'V'):
            continue
        creases.append(
            {
                'v1': tuple(coords[start]),
                'v2': tuple(coords[end]),
                'assignment': label,
            }
        )
    return creases
def _shape_terms(new_state: PaperState, target: dict) -> dict:
    """Return the 3D shape-comparison reward terms for ``new_state`` vs ``target``.

    Always contains ``shape_score`` (0.0 when the target has no
    ``vertices_coords_folded`` or the state has no vertices). Otherwise it also
    carries ``chamfer``, ``chamfer_score``, ``hausdorff``, ``bbox_iou``,
    ``lddt`` and every ``gdt_*`` entry from ``compute_3d_shape_reward``.
    """
    terms = {'shape_score': 0.0}
    target_3d = target.get('vertices_coords_folded')  # 3D target shape
    if target_3d is None:
        return terms
    # Current state vertices are 2D for now; lift them to z=0 (flat creases).
    current_verts = [[x, y, 0.0] for x, y in new_state.graph.vertices.values()]
    if not current_verts:
        return terms
    shape_result = compute_3d_shape_reward(current_verts, target_3d)
    for key in ('chamfer', 'chamfer_score', 'hausdorff', 'bbox_iou', 'lddt'):
        terms[key] = shape_result[key]
    terms['shape_score'] = shape_result['shape_total']
    terms.update({k: v for k, v in shape_result.items() if k.startswith('gdt_')})
    return terms


def compute_reward(
    prev_state: PaperState,
    action_result: dict,
    new_state: PaperState,
    target: dict,
    step: int,
    max_steps: int,
) -> dict:
    """
    Compute the full reward dict for a fold action (lexicographically gated).

    Args:
        prev_state: PaperState BEFORE the action was applied
        action_result: {'valid': bool, 'anchored': bool, 'duplicate': bool, ...};
            missing flags are treated as False
        new_state: PaperState AFTER the action was applied
        target: FOLD target dict
        step: current step index
        max_steps: maximum steps in episode (values below 1 are treated as 1)

    Returns:
        For an invalid action, only {'format': 0.0, 'total': -0.1}.
        Otherwise a dict with keys:
            format, anchored, novelty, kawasaki, maekawa, blb, degree_sanity,
            progress, economy, assignment_accuracy, delta, regression,
            shape_score, completion, efficiency, total
        plus chamfer, chamfer_score, hausdorff, bbox_iou, lddt and gdt_* when
        the target has 'vertices_coords_folded' and the state has vertices.
    """
    r = {}

    # GATE 1: Format — did the action parse and apply? Invalid actions get a
    # flat penalty and no other terms are computed.
    r['format'] = 1.0 if action_result.get('valid', False) else 0.0
    if not r['format']:
        r['total'] = -0.1
        return r

    # GATE 2: Structural sanity (unanchored folds are discounted, not zeroed).
    r['anchored'] = 1.0 if action_result.get('anchored', False) else 0.3
    # Truthiness rather than `is True`, so numpy booleans and other truthy
    # flags also count as duplicates, matching the checks above.
    r['novelty'] = 0.0 if action_result.get('duplicate', False) else 0.2

    # LEVEL 3: Local flat-foldability
    vertex_scores = check_all_vertices(new_state.graph)
    r['kawasaki'] = vertex_scores['kawasaki']
    r['maekawa'] = vertex_scores['maekawa']
    r['blb'] = vertex_scores['blb']
    r['degree_sanity'] = check_degree_sanity(new_state.graph)

    # LEVEL 4: Progress (absolute + delta against the previous state)
    t_edges = target_crease_edges(target)
    old_coverage, _, _ = geometric_crease_coverage(prev_state, t_edges)
    new_coverage, economy, assignment_accuracy = geometric_crease_coverage(new_state, t_edges)
    r['progress'] = new_coverage
    r['economy'] = economy
    r['assignment_accuracy'] = assignment_accuracy
    r['delta'] = max(0.0, new_coverage - old_coverage)
    r['regression'] = min(0.0, new_coverage - old_coverage)  # always <= 0

    # LEVEL 5: 3D shape comparison (AlphaFold-inspired)
    r.update(_shape_terms(new_state, target))

    # LEVEL 6: Completion bonus
    # NOTE(review): granted on every call where the condition holds, not only
    # once per episode — confirm this is intended.
    all_valid = (
        r['kawasaki'] == 1.0
        and r['maekawa'] == 1.0
        and r['blb'] == 1.0
    )
    r['completion'] = 10.0 if (r['progress'] > 0.9 and all_valid) else 0.0

    # LEVEL 7: Efficiency — escalating step cost, from -0.01 at step 0 to
    # -0.02 at step == max_steps. The denominator is clamped so that
    # max_steps == 0 cannot raise ZeroDivisionError.
    r['efficiency'] = -0.01 * (1 + step / max(max_steps, 1))

    # Weighted total (2D crease matching + 3D shape comparison)
    r['total'] = (
        # 2D crease pattern matching
        0.05 * r['anchored']
        + 0.05 * r['novelty']
        + 0.06 * r['kawasaki']
        + 0.06 * r['maekawa']
        + 0.04 * r['blb']
        + 0.04 * r['degree_sanity']
        + 0.15 * r['progress']
        + 0.05 * r['economy']
        + 0.05 * r['assignment_accuracy']
        + 0.10 * r['delta']
        + 0.05 * r['regression']  # regression <= 0, so this is a penalty
        # 3D shape comparison
        + 0.15 * r['shape_score']
        # Bonuses and penalties
        + r['completion']
        + r['efficiency']
    )
    return r
def compute_terminal_reward(
    state: PaperState,
    target: dict,
    max_steps: int,
) -> dict:
    """
    Score the final paper state of a finished fold sequence.

    The whole sequence is treated as a single valid, anchored, non-duplicate
    action applied to a blank PaperState and charged at step == max_steps,
    so delta/regression measure total progress from an unfolded sheet.
    """
    synthetic_action = dict(valid=True, anchored=True, duplicate=False)
    return compute_reward(
        new_state=state,
        prev_state=PaperState(),
        action_result=synthetic_action,
        target=target,
        max_steps=max_steps,
        step=max_steps,
    )