""" Generate paper figures from eval_data/ outputs: - figures/rollout_mse_curve.pdf — MSE(t) with std shading for 4 scenarios - figures/conservation_drift.pdf — horizontal px error + KE error (in-distribution) - figures/collision_decomp.pdf — collision vs free-flight bar chart + quantitative table """ import json import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from pathlib import Path EVAL = Path('/home/alexw/Projects/physics-llm-paper/eval_data') FIGS = Path('/home/alexw/Projects/physics-llm-paper/figures') FIGS.mkdir(exist_ok=True) COLORS = { 'Constraint': '#2196F3', 'Stacking': '#FF9800', 'Collision': '#4CAF50', 'OOD-novel': '#F44336', } def plot_rollout(): data = json.loads((EVAL / 'rollout_mse.json').read_text()) fig, ax = plt.subplots(figsize=(5.5, 3.5)) for scen, row in data.items(): curve = np.array(row['mean_mse_curve']) rmse = np.sqrt(np.maximum(curve, 0)) cat = row['category'] label = f"{scen.replace('_',' ').title()} ({cat})" color = COLORS.get(cat, '#888') steps = np.arange(1, len(rmse) + 1) ax.plot(steps, rmse, color=color, label=label, linewidth=1.8) # std shading (convert MSE std to RMSE std via √(MSE+std) - √MSE approx) if 'std_mse_curve' in row: std = np.array(row['std_mse_curve']) # use per-scene RMSE std if per_scene_curves available if 'per_scene_curves' in row: per_sc = np.array(row['per_scene_curves']) per_rmse = np.sqrt(np.maximum(per_sc, 0)) rmse_std = np.nanstd(per_rmse, axis=0) else: rmse_std = std / (2 * rmse + 1e-6) # delta method ax.fill_between(steps, np.maximum(rmse - rmse_std, 0), rmse + rmse_std, color=color, alpha=0.15) ax.set_xlabel('Rollout step $t$', fontsize=10) ax.set_ylabel('Avg RMSE (px)', fontsize=10) ax.set_title('Multi-step autoregressive rollout error', fontsize=10) ax.legend(fontsize=8, loc='upper left') ax.set_xlim(1, max(len(v['mean_mse_curve']) for v in data.values())) ax.set_yscale('log') ax.grid(True, alpha=0.3) fig.tight_layout() for ext in ('pdf', 'png'): fig.savefig(FIGS / f'rollout_mse_curve.{ext}', dpi=200, bbox_inches='tight') plt.close() print("Saved rollout_mse_curve.pdf/.png") def plot_conservation(): p = EVAL / 'conservation.json' if not p.exists(): print("conservation.json not found, skipping") return data = json.loads(p.read_text()) px_curve = np.array(data['px_err_curve']) px_std = np.array(data.get('px_err_std_curve', np.zeros_like(px_curve))) mean_ke = data.get('mean_ke_err_free_flight', None) std_ke = data.get('std_ke_err_free_flight', 0.0) steps = np.arange(1, len(px_curve) + 1) fig, axes = plt.subplots(1, 2, figsize=(7.5, 3.2)) # Left: horizontal momentum error curve axes[0].plot(steps, px_curve * 100, color='#2196F3', linewidth=1.8, label='mean px error') axes[0].fill_between(steps, np.maximum((px_curve - px_std) * 100, 0), (px_curve + px_std) * 100, color='#2196F3', alpha=0.15) axes[0].set_xlabel('Rollout step', fontsize=10) axes[0].set_ylabel('Horizontal momentum error (%)', fontsize=9) axes[0].set_title('Horizontal momentum\n(gravity-free axis)', fontsize=9) axes[0].set_ylim(0, 100) axes[0].grid(True, alpha=0.3) # Right: KE error (scalar, show as bar with error bar) if mean_ke is not None: axes[1].bar(['Free-flight KE\nerror'], [mean_ke * 100], yerr=[std_ke * 100], color='#F44336', alpha=0.75, width=0.4, capsize=8) axes[1].set_ylabel('|KE_pred − KE_gt| / KE_gt (%)', fontsize=9) axes[1].set_title('Kinetic energy error\n(free-flight frames only)', fontsize=9) axes[1].set_ylim(0, 100) axes[1].grid(True, alpha=0.3, axis='y') fig.tight_layout() for ext in ('pdf', 'png'): fig.savefig(FIGS / f'conservation_drift.{ext}', dpi=200, bbox_inches='tight') plt.close() print("Saved conservation_drift.pdf/.png") def plot_collision_decomp(): data = json.loads((EVAL / 'collision_decomp.json').read_text()) cats = data['per_category'] CAT_ORDER = ['Collision', 'Stacking', 'Ramp', 'Constraint', 'Minigame', 'Complex'] cats_present = [c for c in CAT_ORDER if c in cats] col_frac = [cats[c]['col_frac'] * 100 for c in cats_present] col_lin = [cats[c]['col_lin_mse'] for c in cats_present] flight_lin = [cats[c]['flight_lin_mse'] for c in cats_present] x = np.arange(len(cats_present)) w = 0.35 fig, axes = plt.subplots(1, 2, figsize=(9, 3.2)) # Left: fraction of collision frames per category axes[0].bar(x, col_frac, color='#F44336', alpha=0.8) axes[0].set_xticks(x) axes[0].set_xticklabels(cats_present, fontsize=8) axes[0].set_ylabel('Collision frames (%)', fontsize=9) axes[0].set_title('Fraction of collision frames\nper category', fontsize=9) axes[0].set_ylim(0, 110) axes[0].grid(True, alpha=0.3, axis='y') # Right: linear extrap MSE on collision vs free-flight frames axes[1].bar(x - w/2, col_lin, width=w, label='Collision frames', color='#F44336', alpha=0.8) axes[1].bar(x + w/2, flight_lin, width=w, label='Free-flight frames', color='#4CAF50', alpha=0.8) axes[1].set_xticks(x) axes[1].set_xticklabels(cats_present, fontsize=8) axes[1].set_ylabel('Linear extrap MSE (px²)', fontsize=9) axes[1].set_title('Prediction difficulty:\ncollision vs. free-flight', fontsize=9) axes[1].legend(fontsize=8) axes[1].set_yscale('log') axes[1].grid(True, alpha=0.3, axis='y') fig.tight_layout() for ext in ('pdf', 'png'): fig.savefig(FIGS / f'collision_decomp.{ext}', dpi=200, bbox_inches='tight') plt.close() print("Saved collision_decomp.pdf/.png") if __name__ == '__main__': plot_rollout() plot_conservation() plot_collision_decomp() print("Done — all figures written to", FIGS)