# optigami/env/shape_reward.py
# Author: sissississi — commit 9aba971:
# "Add 3D shape comparison reward module (AlphaFold-inspired)"
"""
3D Shape Comparison Rewards (AlphaFold-inspired)
Computes how close a folded origami shape is to a target 3D shape using:
- Chamfer Distance: average nearest-neighbor distance between point clouds
- Hausdorff Distance: worst-case misalignment
- GDT-TS-like score: % of vertices within distance thresholds (for logging)
- Bounding box IoU: does the folded shape fit the target dimensions?
These metrics are fast (<1ms for typical origami meshes with 10-100 vertices)
and can be computed per-step or at episode end.
Usage:
from env.shape_reward import compute_3d_shape_reward
reward = compute_3d_shape_reward(
predicted_vertices=[[0,0,0], [1,0,0], [1,1,0], [0,1,0.5]],
target_vertices=[[0,0,0], [1,0,0], [1,1,0], [0,1,0]],
)
# reward = {'chamfer': 0.03, 'hausdorff': 0.5, 'gdt_1': 0.75, ...}
"""
from __future__ import annotations
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import directed_hausdorff
def chamfer_distance(P: np.ndarray, Q: np.ndarray) -> float:
    """
    Symmetric squared Chamfer Distance between two point clouds.

    CD(P,Q) = mean_p min_q ||p-q||^2 + mean_q min_p ||q-p||^2
    Lower is better; 0 means identical shapes. Returns +inf when either
    cloud is empty (no meaningful comparison possible).
    """
    if min(len(P), len(Q)) == 0:
        return float('inf')
    # Nearest-neighbor distances in each direction via KD-tree lookups.
    nn_p_to_q = cKDTree(Q).query(P)[0]
    nn_q_to_p = cKDTree(P).query(Q)[0]
    return float(np.mean(nn_p_to_q ** 2) + np.mean(nn_q_to_p ** 2))
def hausdorff_dist(P: np.ndarray, Q: np.ndarray) -> float:
    """
    Symmetric Hausdorff distance: the largest nearest-neighbor gap between
    the two point clouds, i.e. the worst-case misalignment.
    Returns +inf when either cloud is empty.
    """
    if min(len(P), len(Q)) == 0:
        return float('inf')
    # directed_hausdorff is asymmetric; take the max of both directions.
    return float(max(directed_hausdorff(P, Q)[0],
                     directed_hausdorff(Q, P)[0]))
def gdt_ts_score(P: np.ndarray, Q: np.ndarray, thresholds: tuple = (0.01, 0.02, 0.05, 0.10)) -> dict:
    """
    GDT-TS-like score: fraction of predicted vertices within distance thresholds of target.

    Inspired by protein structure prediction metrics. For each threshold t,
    compute the fraction of vertices in P that have a nearest neighbor in Q
    within distance t.

    Returns dict like:
        {'gdt_1': 0.8, 'gdt_2': 0.9, 'gdt_5': 1.0, 'gdt_10': 1.0, 'gdt_avg': 0.925}
    The 'gdt_avg' key is always present, including for empty inputs.
    """
    # Key name from a threshold, e.g. 0.05 -> 'gdt_5'. round() guards
    # against float artifacts (int(0.29 * 100) == 28 without it).
    def _key(t: float) -> str:
        return f'gdt_{int(round(t * 100))}'

    if len(P) == 0 or len(Q) == 0:
        # Bug fix: previously this branch omitted 'gdt_avg', so the return
        # schema differed from the normal path and broke consumers.
        scores = {_key(t): 0.0 for t in thresholds}
        scores['gdt_avg'] = 0.0
        return scores
    tree_Q = cKDTree(Q)
    distances, _ = tree_Q.query(P)
    scores = {}
    for t in thresholds:
        scores[_key(t)] = float(np.mean(distances <= t))
    scores['gdt_avg'] = float(np.mean(list(scores.values())))
    return scores
def bounding_box_iou(P: np.ndarray, Q: np.ndarray) -> float:
    """
    3D axis-aligned bounding-box Intersection over Union, in [0, 1].

    Degenerate (zero-thickness) extents are floored at a small epsilon on
    BOTH the box volumes and the intersection. Previously only the box
    volumes were floored, so two overlapping coplanar shapes (e.g. flat
    origami sheets in the z=0 plane) always scored 0 even when identical;
    now a flat-vs-flat comparison reduces to the correct 2D area IoU.
    Returns 0.0 when either cloud is empty.
    """
    if len(P) == 0 or len(Q) == 0:
        return 0.0
    eps = 1e-10
    # Promote 2D point clouds to 3D with z = 0.
    if P.shape[1] == 2:
        P = np.column_stack([P, np.zeros(len(P))])
    if Q.shape[1] == 2:
        Q = np.column_stack([Q, np.zeros(len(Q))])
    p_min, p_max = P.min(axis=0), P.max(axis=0)
    q_min, q_max = Q.min(axis=0), Q.max(axis=0)
    # Intersection box; a negative raw extent means no overlap on that axis.
    inter_min = np.maximum(p_min, q_min)
    inter_max = np.minimum(p_max, q_max)
    raw = inter_max - inter_min
    # Where the boxes do overlap (raw >= 0), floor the extent at eps so a
    # flat overlap is measured consistently with the eps-floored volumes.
    inter_dims = np.where(raw >= 0, np.maximum(raw, eps), 0.0)
    inter_vol = float(np.prod(inter_dims))
    # Union via inclusion-exclusion, with degenerate axes floored at eps.
    p_vol = float(np.prod(np.maximum(eps, p_max - p_min)))
    q_vol = float(np.prod(np.maximum(eps, q_max - q_min)))
    union_vol = p_vol + q_vol - inter_vol
    if union_vol < 1e-15:
        return 0.0
    return inter_vol / union_vol
def lddt_like_score(P: np.ndarray, Q: np.ndarray, cutoff: float = 0.15, thresholds: tuple = (0.005, 0.01, 0.02, 0.04)) -> float:
    """
    lDDT-like (Local Distance Difference Test) score for origami, after
    AlphaFold's lDDT metric.

    For every vertex pair within `cutoff` of each other in the target Q,
    checks how well the predicted shape P preserves that pairwise distance,
    averaged over the tolerance `thresholds`. Superposition-free — no
    alignment required — so it measures local fold accuracy: are nearby
    vertices still in the right relative positions?

    Vertices are matched by index; the longer cloud is truncated to the
    shorter one. Returns a score in [0, 1]; higher is better.
    """
    m = min(len(P), len(Q))
    if m < 2:
        return 1.0  # no pairs to compare — vacuously perfect
    pred, tgt = P[:m], Q[:m]
    # Dense pairwise distance matrices for both shapes.
    d_tgt = np.linalg.norm(tgt[:, None, :] - tgt[None, :, :], axis=-1)
    d_pred = np.linalg.norm(pred[:, None, :] - pred[None, :, :], axis=-1)
    # Keep only pairs that are local in the target; drop self-pairs.
    local = (d_tgt < cutoff) & (d_tgt > 1e-10)
    if not local.any():
        return 1.0
    deviation = np.abs(d_pred[local] - d_tgt[local])
    # Fraction of local pairs preserved at each tolerance, then averaged.
    per_threshold = [float(np.mean(deviation < t)) for t in thresholds]
    return float(np.mean(per_threshold))
def compute_3d_shape_reward(
    predicted_vertices: list | np.ndarray,
    target_vertices: list | np.ndarray,
    weights: dict | None = None,
) -> dict:
    """
    Compute all 3D shape-comparison metrics between a predicted fold state
    and a target shape, plus a weighted composite 'shape_total'.

    Args:
        predicted_vertices: Nx2 or Nx3 vertex positions (current fold state).
            A flat 1D array is reshaped heuristically: 2 columns when its
            length is even, otherwise 3.
        target_vertices: Mx2 or Mx3 vertex positions (target shape).
        weights: optional weight dict for the composite score; missing keys
            fall back to the built-in defaults.

    Returns:
        dict with raw metrics, normalized scores, GDT components (logging),
        and the weighted 'shape_total'.
    """
    def _as_3d(arr) -> np.ndarray:
        # Normalize any accepted input into an (n, 3) float64 array.
        pts = np.asarray(arr, dtype=np.float64)
        if pts.ndim == 1:
            pts = pts.reshape(-1, 2 if len(pts) % 2 == 0 else 3)
        if pts.shape[1] == 2:
            pts = np.column_stack([pts, np.zeros(len(pts))])
        return pts

    P = _as_3d(predicted_vertices)
    Q = _as_3d(target_vertices)

    default_w = {'chamfer': 5.0, 'hausdorff': 1.0, 'bbox_iou': 3.0, 'lddt': 2.0}
    w = weights or default_w

    cd = chamfer_distance(P, Q)
    hd = hausdorff_dist(P, Q)
    result = {
        'chamfer': cd,
        # Rough normalization into ~[0, 1], clipped at 0 for poor shapes.
        'chamfer_score': max(0.0, 1.0 - cd * 10.0),
        'hausdorff': hd,
        'hausdorff_score': max(0.0, 1.0 - hd * 2.0),
        'bbox_iou': bounding_box_iou(P, Q),
        'lddt': lddt_like_score(P, Q),
    }
    # GDT-TS components, mainly for logging.
    result.update(gdt_ts_score(P, Q))

    # Weighted composite; w.get keeps caller-supplied partial dicts working.
    result['shape_total'] = (
        w.get('chamfer', default_w['chamfer']) * result['chamfer_score']
        + w.get('hausdorff', default_w['hausdorff']) * result['hausdorff_score']
        + w.get('bbox_iou', default_w['bbox_iou']) * result['bbox_iou']
        + w.get('lddt', default_w['lddt']) * result['lddt']
    )
    return result