"""
3D Shape Comparison Rewards (AlphaFold-inspired)

Computes how close a folded origami shape is to a target 3D shape using:
- Chamfer Distance: mean squared nearest-neighbor distance between point clouds
- Hausdorff Distance: worst-case misalignment
- lDDT-like score: superposition-free check that vertex pairs which are close
  in the target keep their pairwise distance in the prediction
- GDT-TS-like score: % of vertices within distance thresholds (for logging)
- Bounding box IoU: does the folded shape fit the target dimensions?

These metrics are fast (<1ms for typical origami meshes with 10-100 vertices)
and can be computed per-step or at episode end.

Usage:
    from env.shape_reward import compute_3d_shape_reward

    reward = compute_3d_shape_reward(
        predicted_vertices=[[0,0,0], [1,0,0], [1,1,0], [0,1,0.5]],
        target_vertices=[[0,0,0], [1,0,0], [1,1,0], [0,1,0]],
    )
    # reward = {'chamfer': 0.125, 'hausdorff': 0.5, 'gdt_1': 0.75, ...}
"""

from __future__ import annotations

import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import directed_hausdorff

# Minimum per-axis box extent used by bounding_box_iou. Flooring degenerate
# (zero-thickness) axes lets flat, coplanar shapes be compared by area
# instead of always scoring 0.
_BBOX_EPS = 1e-10

# Default weights for the composite 'shape_total' score.
_DEFAULT_SHAPE_WEIGHTS = {
    'chamfer': 5.0,
    'hausdorff': 1.0,
    'bbox_iou': 3.0,
    'lddt': 2.0,
}


def _to_3d(points) -> np.ndarray:
    """Return *points* as a float64 array, appending z = 0 to 2-column input."""
    arr = np.asarray(points, dtype=np.float64)
    if arr.ndim == 2 and arr.shape[1] == 2:
        arr = np.column_stack([arr, np.zeros(len(arr))])
    return arr


def _as_vertex_array(vertices) -> np.ndarray:
    """Convert a vertex list (nested or flat) into an (N, 3) float64 array.

    A flat 1-D sequence is split into 2-D points when its length is even and
    into 3-D points otherwise, so e.g. a flat list of 6 numbers is read as
    three 2-D points. A length that is odd and not divisible by 3 makes the
    reshape raise ValueError.
    """
    arr = np.asarray(vertices, dtype=np.float64)
    if arr.ndim == 1:
        arr = arr.reshape(-1, 2 if len(arr) % 2 == 0 else 3)
    return _to_3d(arr)


def _gdt_key(threshold: float) -> str:
    """Result key for a GDT threshold, expressed in hundredths (0.05 -> 'gdt_5').

    ``round`` is used so float error cannot shift the name
    (``int(0.29 * 100)`` is 28). Thresholds below 0.005 all map to 'gdt_0'.
    """
    return f'gdt_{int(round(threshold * 100))}'


def chamfer_distance(P: np.ndarray, Q: np.ndarray) -> float:
    """
    Symmetric Chamfer Distance between two point clouds.

    CD(P,Q) = (1/|P|) * sum_p(min_q ||p-q||^2) + (1/|Q|) * sum_q(min_p ||q-p||^2)

    Distances are squared, so the result is in squared length units.
    Lower = better. 0 = identical point sets. Returns inf if either cloud
    is empty. 2-column input is treated as lying in the z = 0 plane.
    """
    if len(P) == 0 or len(Q) == 0:
        return float('inf')
    P = _to_3d(P)
    Q = _to_3d(Q)

    # Nearest-neighbour distances in both directions.
    d_pq, _ = cKDTree(Q).query(P)
    d_qp, _ = cKDTree(P).query(Q)

    return float(np.mean(d_pq ** 2) + np.mean(d_qp ** 2))


def hausdorff_dist(P: np.ndarray, Q: np.ndarray) -> float:
    """
    Symmetric Hausdorff Distance — max of the two directed Hausdorff distances.

    Captures worst-case misalignment. Returns inf if either cloud is empty.
    2-column input is treated as lying in the z = 0 plane.
    """
    if len(P) == 0 or len(Q) == 0:
        return float('inf')
    P = _to_3d(P)
    Q = _to_3d(Q)

    d_forward = directed_hausdorff(P, Q)[0]
    d_backward = directed_hausdorff(Q, P)[0]
    return float(max(d_forward, d_backward))


def gdt_ts_score(P: np.ndarray, Q: np.ndarray,
                 thresholds: tuple = (0.01, 0.02, 0.05, 0.10)) -> dict:
    """
    GDT-TS-like score: fraction of predicted vertices within distance
    thresholds of the target. Inspired by protein structure prediction metrics.

    For each threshold t, compute the fraction of vertices in P whose nearest
    neighbor in Q is within distance t (inclusive). Unlike protein GDT-TS, no
    superposition search is done; points are compared in the given frame.

    Keys are 'gdt_<t in hundredths>' plus 'gdt_avg', the mean over the given
    thresholds. Example:
        {'gdt_1': 0.8, 'gdt_2': 0.9, 'gdt_5': 1.0, 'gdt_10': 1.0, 'gdt_avg': 0.925}

    If either cloud is empty, every score (including 'gdt_avg') is 0.0.
    """
    if len(P) == 0 or len(Q) == 0:
        empty = {_gdt_key(t): 0.0 for t in thresholds}
        empty['gdt_avg'] = 0.0
        return empty

    distances, _ = cKDTree(_to_3d(Q)).query(_to_3d(P))

    fractions = [float(np.mean(distances <= t)) for t in thresholds]
    scores = {_gdt_key(t): frac for t, frac in zip(thresholds, fractions)}
    # Average the per-threshold list, not the dict, so thresholds that share a
    # key name still count once each.
    scores['gdt_avg'] = float(np.mean(fractions)) if fractions else 0.0
    return scores


def bounding_box_iou(P: np.ndarray, Q: np.ndarray) -> float:
    """
    3D bounding box Intersection over Union.

    Computes axis-aligned bounding boxes of both point clouds and returns
    their volumetric IoU in [0, 1]. Each box extent, and the intersection
    extent on axes where the boxes overlap, is floored at a tiny epsilon.
    This means two flat shapes in the same plane are compared by area rather
    than always scoring 0. A flat shape compared with a thick one scores ~0.

    Returns 0.0 if either cloud is empty or the boxes do not overlap.
    2-column input is treated as lying in the z = 0 plane.
    """
    if len(P) == 0 or len(Q) == 0:
        return 0.0
    P = _to_3d(P)
    Q = _to_3d(Q)

    p_min, p_max = P.min(axis=0), P.max(axis=0)
    q_min, q_max = Q.min(axis=0), Q.max(axis=0)

    # Intersection: empty as soon as any axis has a negative overlap.
    inter_extent = np.minimum(p_max, q_max) - np.maximum(p_min, q_min)
    if np.any(inter_extent < 0):
        return 0.0
    # Same floor as the box volumes, so identical boxes give exactly 1.0 and
    # inter_vol <= min(p_vol, q_vol) still holds.
    inter_vol = float(np.prod(np.maximum(_BBOX_EPS, inter_extent)))

    # Union
    p_vol = float(np.prod(np.maximum(_BBOX_EPS, p_max - p_min)))
    q_vol = float(np.prod(np.maximum(_BBOX_EPS, q_max - q_min)))
    union_vol = p_vol + q_vol - inter_vol

    if union_vol <= 0.0:
        return 0.0
    return inter_vol / union_vol


def lddt_like_score(P: np.ndarray, Q: np.ndarray, cutoff: float = 0.15,
                    thresholds: tuple = (0.005, 0.01, 0.02, 0.04)) -> float:
    """
    lDDT-like (Local Distance Difference Test) score for origami.

    Inspired by AlphaFold's lDDT metric. For each pair of vertices whose
    distance in the target shape Q is below `cutoff`, check whether their
    pairwise distance in the predicted shape P differs by less than each
    threshold. The score is the mean over thresholds of the fraction of
    preserved pairs.

    This is superposition-free — it doesn't require alignment. It measures
    local fold accuracy: are nearby vertices still in the right relative
    positions?

    Vertices are matched by index (P[i] <-> Q[i]). If the arrays differ in
    length, the extra vertices of the longer one are ignored. Target pairs
    at distance <= 1e-10 (self-pairs, coincident vertices) are skipped.

    Returns a score in [0, 1]; higher = better. Returns 0.0 if either shape
    is empty. Returns 1.0 (vacuously) if fewer than two vertices can be
    paired or no target pair lies within `cutoff`.
    NOTE(review): with the default cutoff, unit-scale meshes whose vertices
    are all farther apart than 0.15 always score 1.0 — confirm intended.
    """
    if len(P) == 0 or len(Q) == 0:
        return 0.0
    n = min(len(P), len(Q))
    if n < 2:
        return 1.0

    P_n = _to_3d(P)[:n]
    Q_n = _to_3d(Q)[:n]

    # Full pairwise distance matrices (O(n^2) memory; fine for small meshes).
    Q_dists = np.linalg.norm(Q_n[:, None, :] - Q_n[None, :, :], axis=-1)
    P_dists = np.linalg.norm(P_n[:, None, :] - P_n[None, :, :], axis=-1)

    # Only score pairs that are local in the target; exclude self-pairs.
    mask = (Q_dists < cutoff) & (Q_dists > 1e-10)
    if not np.any(mask):
        return 1.0

    dist_diffs = np.abs(P_dists[mask] - Q_dists[mask])

    # For each threshold, fraction of pairs preserved
    scores = [float(np.mean(dist_diffs < t)) for t in thresholds]
    return float(np.mean(scores))


def compute_3d_shape_reward(
    predicted_vertices: list | np.ndarray,
    target_vertices: list | np.ndarray,
    weights: dict | None = None,
) -> dict:
    """
    Compute all 3D shape comparison metrics between predicted and target shapes.

    Args:
        predicted_vertices: Nx2 or Nx3 vertex positions (current fold state),
            or a flat sequence (split into 2-D points if its length is even,
            3-D points otherwise). 2-D points get z = 0.
        target_vertices: Mx2 or Mx3 vertex positions (target shape), same
            conventions as predicted_vertices.
        weights: optional weight dict for the composite score, with keys
            'chamfer', 'hausdorff', 'bbox_iou', 'lddt'. Missing keys fall back
            to the defaults 5.0 / 1.0 / 3.0 / 2.0.

    Returns:
        dict with raw metrics ('chamfer', 'hausdorff'), their [0, 1] scores
        ('chamfer_score', 'hausdorff_score'), 'bbox_iou', 'lddt', the GDT
        scores ('gdt_1', 'gdt_2', 'gdt_5', 'gdt_10', 'gdt_avg'), and the
        weighted 'shape_total' (at most 11.0 with default weights).

    Raises:
        ValueError: if a flat vertex sequence cannot be reshaped into points.
    """
    P = _as_vertex_array(predicted_vertices)
    Q = _as_vertex_array(target_vertices)

    w = {**_DEFAULT_SHAPE_WEIGHTS, **(weights or {})}

    result = {}

    # Core metrics
    cd = chamfer_distance(P, Q)
    result['chamfer'] = cd
    # Linear ramp: 1 at cd = 0, clipped to 0 for cd >= 0.1 (squared length units).
    result['chamfer_score'] = max(0.0, 1.0 - cd * 10.0)

    hd = hausdorff_dist(P, Q)
    result['hausdorff'] = hd
    # Linear ramp: 1 at hd = 0, clipped to 0 for hd >= 0.5.
    result['hausdorff_score'] = max(0.0, 1.0 - hd * 2.0)

    result['bbox_iou'] = bounding_box_iou(P, Q)
    result['lddt'] = lddt_like_score(P, Q)

    # GDT-TS scores for logging (not part of shape_total)
    result.update(gdt_ts_score(P, Q))

    # Composite score
    result['shape_total'] = (
        w['chamfer'] * result['chamfer_score']
        + w['hausdorff'] * result['hausdorff_score']
        + w['bbox_iou'] * result['bbox_iou']
        + w['lddt'] * result['lddt']
    )

    return result