# optigami/env/rewards.py
# Add 3D shape comparison reward module (AlphaFold-inspired)
# commit 9aba971 (author: sissississi)
import json
import numpy as np
from .verifier import check_all_vertices, check_degree_sanity, geometric_crease_coverage
from .paper_state import PaperState
from .shape_reward import compute_3d_shape_reward
def load_target(target_path: str) -> dict:
    """Read a FOLD-format target file from disk and return its parsed dict."""
    with open(target_path) as fh:
        data = json.load(fh)
    return data
def target_crease_edges(target: dict) -> list[dict]:
    """
    Pull the mountain/valley crease edges out of a FOLD target dict.

    Returns a list of {'v1': (x1, y1), 'v2': (x2, y2),
    'assignment': 'M'|'V'} dicts; border/flat edges are skipped.
    """
    coords = target['vertices_coords']
    creases = []
    for idx, (a, b) in enumerate(target['edges_vertices']):
        kind = target['edges_assignment'][idx]
        # Only mountain/valley folds count as creases.
        if kind not in ('M', 'V'):
            continue
        creases.append({
            'v1': tuple(coords[a]),
            'v2': tuple(coords[b]),
            'assignment': kind,
        })
    return creases
def compute_reward(
prev_state: PaperState,
action_result: dict,
new_state: PaperState,
target: dict,
step: int,
max_steps: int,
) -> dict:
"""
Compute the full reward dict for a fold action (lexicographically gated).
Args:
prev_state: PaperState BEFORE the action was applied
action_result: {'valid': bool, 'anchored': bool, 'duplicate': bool, ...}
new_state: PaperState AFTER the action was applied
target: FOLD target dict
step: current step index
max_steps: maximum steps in episode
Returns dict with keys:
format, anchored, novelty, kawasaki, maekawa, blb, degree_sanity,
progress, economy, assignment_accuracy, delta, regression,
completion, efficiency, total
"""
r = {}
# GATE 1: Format — did the action parse and apply?
r['format'] = 1.0 if action_result.get('valid', False) else 0.0
if not r['format']:
r['total'] = -0.1
return r
# GATE 2: Structural sanity
r['anchored'] = 1.0 if action_result.get('anchored', False) else 0.3
r['novelty'] = 0.0 if action_result.get('duplicate', False) is True else 0.2
# LEVEL 3: Local flat-foldability
vertex_scores = check_all_vertices(new_state.graph)
r['kawasaki'] = vertex_scores['kawasaki']
r['maekawa'] = vertex_scores['maekawa']
r['blb'] = vertex_scores['blb']
r['degree_sanity'] = check_degree_sanity(new_state.graph)
# LEVEL 4: Progress (absolute + delta)
t_edges = target_crease_edges(target)
old_coverage, _, _ = geometric_crease_coverage(prev_state, t_edges)
new_coverage, economy, assignment_accuracy = geometric_crease_coverage(new_state, t_edges)
r['progress'] = new_coverage
r['economy'] = economy
r['assignment_accuracy'] = assignment_accuracy
r['delta'] = max(0.0, new_coverage - old_coverage)
r['regression'] = min(0.0, new_coverage - old_coverage)
# LEVEL 5: 3D Shape comparison (AlphaFold-inspired)
# If the target has 3D vertex data, compare the current fold state's
# vertex positions against the target's folded shape.
r['shape_score'] = 0.0
target_3d = target.get('vertices_coords_folded') # 3D target shape
if target_3d is not None:
# Current state vertices (2D for now; z=0 for flat creases)
current_verts = []
for vid, (x, y) in new_state.graph.vertices.items():
current_verts.append([x, y, 0.0])
if current_verts:
shape_result = compute_3d_shape_reward(current_verts, target_3d)
r['chamfer'] = shape_result['chamfer']
r['chamfer_score'] = shape_result['chamfer_score']
r['hausdorff'] = shape_result['hausdorff']
r['bbox_iou'] = shape_result['bbox_iou']
r['lddt'] = shape_result['lddt']
r['shape_score'] = shape_result['shape_total']
r.update({k: v for k, v in shape_result.items() if k.startswith('gdt_')})
# LEVEL 6: Completion bonus
all_valid = (
r['kawasaki'] == 1.0
and r['maekawa'] == 1.0
and r['blb'] == 1.0
)
r['completion'] = 10.0 if (r['progress'] > 0.9 and all_valid) else 0.0
# LEVEL 7: Efficiency — escalating step cost
r['efficiency'] = -0.01 * (1 + step / max_steps)
# Weighted total (2D crease matching + 3D shape comparison)
r['total'] = (
# 2D crease pattern matching (existing)
0.05 * r['anchored']
+ 0.05 * r['novelty']
+ 0.06 * r['kawasaki']
+ 0.06 * r['maekawa']
+ 0.04 * r['blb']
+ 0.04 * r['degree_sanity']
+ 0.15 * r['progress']
+ 0.05 * r['economy']
+ 0.05 * r['assignment_accuracy']
+ 0.10 * r['delta']
+ 0.05 * r['regression']
# 3D shape comparison (new — AlphaFold-inspired)
+ 0.15 * r['shape_score']
# Bonuses and penalties
+ r['completion']
+ r['efficiency']
)
return r
def compute_terminal_reward(
    state: PaperState,
    target: dict,
    max_steps: int,
) -> dict:
    """
    Score the final state after a complete fold sequence.

    Treats a fresh sheet of paper as the baseline, pretends the terminal
    state was produced by a single valid, anchored, non-duplicate action,
    and evaluates it at step == max_steps.
    """
    synthetic_action = {
        'valid': True,
        'anchored': True,
        'duplicate': False,
    }
    return compute_reward(
        prev_state=PaperState(),
        action_result=synthetic_action,
        new_state=state,
        target=target,
        step=max_steps,
        max_steps=max_steps,
    )