sissississi's picture
Add RL training environment with OpenEnv backend
bc52096
"""Shape matching for reward computation.
Computes similarity between the LLM's folded shape and the target shape.
Like AlphaFold's RMSD but for origami vertex positions.
"""
import numpy as np
from scipy.spatial.distance import cdist
def compute_shape_match(
predicted: np.ndarray,
target: np.ndarray,
) -> float:
"""Compute shape similarity between predicted and target positions.
Uses chamfer distance normalized by bounding box diagonal.
Aligns shapes by centering before comparison.
Args:
predicted: (N, 3) predicted vertex positions.
target: (M, 3) target vertex positions.
Returns:
Similarity score in [0, 1]. 1.0 = perfect match.
"""
if len(predicted) == 0 or len(target) == 0:
return 0.0
# Center both point clouds
pred_centered = predicted - predicted.mean(axis=0)
target_centered = target - target.mean(axis=0)
# Try multiple rotations and pick best match
best_score = 0.0
for rotation in _get_alignment_rotations():
rotated = pred_centered @ rotation.T
score = _chamfer_similarity(rotated, target_centered)
best_score = max(best_score, score)
return best_score
def _chamfer_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Chamfer distance converted to similarity score."""
d = cdist(a, b)
# Forward: for each point in a, min distance to b
forward = d.min(axis=1).mean()
# Backward: for each point in b, min distance to a
backward = d.min(axis=0).mean()
chamfer = (forward + backward) / 2.0
# Normalize by bounding box diagonal of target
all_pts = np.vstack([a, b])
bbox_diag = np.linalg.norm(all_pts.max(axis=0) - all_pts.min(axis=0))
if bbox_diag < 1e-12:
return 1.0 if chamfer < 1e-12 else 0.0
similarity = max(0.0, 1.0 - chamfer / bbox_diag)
return similarity
def _get_alignment_rotations() -> list[np.ndarray]:
"""Generate rotation matrices for alignment search.
Identity + 90 deg rotations around each axis + mirrors (15 total).
"""
I = np.eye(3)
rotations = [I]
# 90 deg rotations around Z axis
for k in range(1, 4):
angle = k * np.pi / 2
c, s = np.cos(angle), np.sin(angle)
rotations.append(np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]))
# 90 deg rotations around X axis
for k in range(1, 4):
angle = k * np.pi / 2
c, s = np.cos(angle), np.sin(angle)
rotations.append(np.array([[1, 0, 0], [0, c, -s], [0, s, c]]))
# 90 deg rotations around Y axis
for k in range(1, 4):
angle = k * np.pi / 2
c, s = np.cos(angle), np.sin(angle)
rotations.append(np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]))
# Mirrors
rotations.append(np.diag([-1.0, 1.0, 1.0]))
rotations.append(np.diag([1.0, -1.0, 1.0]))
rotations.append(np.diag([1.0, 1.0, -1.0]))
return rotations