"""Publication-quality figures for DPA paper."""
import json
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

matplotlib.rcParams.update({
    "font.size": 12,
    "font.family": "serif",
    "axes.labelsize": 14,
    "axes.titlesize": 15,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
    "figure.dpi": 150,
})

COLORS = {
    "full_transformer": "#2196F3",
    "pure_linear": "#FF9800",
    "uniform_hybrid": "#4CAF50",
    "dpa": "#E91E63",
    "dpa_fixed": "#9C27B0",
}

# Colour used for model families that have no entry in COLORS.
_DEFAULT_COLOR = "#666"


def _base_name(model_type):
    """Return the model family of *model_type* by dropping any ``_r...`` suffix.

    For example ``"dpa_r0.1"`` -> ``"dpa"``; names without ``"_r"`` are
    returned unchanged.
    """
    return model_type.split("_r")[0]


def _save_figure(fig, save_path):
    """Lay out *fig*, write it to *save_path* (creating parent dirs), then close it."""
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)


def plot_accuracy_vs_flops(results_path, save_path="figures/accuracy_vs_flops.pdf"):
    """Main figure: accuracy vs compute tradeoff.

    Scatters each run's perplexity against its relative FLOPs, coloured by
    model family (DPA runs drawn as larger stars), and joins the ``dpa_r*``
    ratio sweep with a dashed line ordered by FLOPs.

    Args:
        results_path: JSON file containing a list of result dicts, each with
            ``"model_type"``, ``"flops_ratio"`` and ``"perplexity"`` keys.
        save_path: Output figure path; parent directories are created.
    """
    with open(results_path) as f:
        results = json.load(f)

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    labelled = set()
    for r in results:
        base = _base_name(r["model_type"])
        is_dpa = base == "dpa"
        # One legend entry per model family: only the first point of each
        # family is labelled, and it is labelled with the family name.
        label = "_nolegend_" if base in labelled else base
        labelled.add(base)
        ax.scatter(
            r["flops_ratio"],
            r["perplexity"],
            c=COLORS.get(base, _DEFAULT_COLOR),
            marker="*" if is_dpa else "o",
            s=120 if is_dpa else 60,
            zorder=5,
            label=label,
        )

    # Connect DPA ratio-sweep points in order of increasing FLOPs.
    dpa_results = sorted(
        (r for r in results if r["model_type"].startswith("dpa_r")),
        key=lambda r: r["flops_ratio"],
    )
    if dpa_results:
        ax.plot(
            [r["flops_ratio"] for r in dpa_results],
            [r["perplexity"] for r in dpa_results],
            c=COLORS["dpa"],
            linewidth=2,
            alpha=0.5,
            linestyle="--",
        )

    ax.set_xlabel("FLOPs (relative to Full Transformer)")
    ax.set_ylabel("Perplexity ↓")
    ax.set_title("Decision Point Attention: Accuracy vs Compute")
    ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    _save_figure(fig, save_path)


def _baseline_perplexity(results, model_type):
    """Return the perplexity of the first result whose model_type matches, else None."""
    return next(
        (r["perplexity"] for r in results if r["model_type"] == model_type),
        None,
    )


def plot_decision_ratio_ablation(results_path, save_path="figures/ratio_ablation.pdf"):
    """Ablation: effect of decision point ratio.

    Left panel plots DPA perplexity against decision-point ratio with the
    full-transformer and pure-linear baselines (when present) as horizontal
    lines; right panel plots relative FLOPs against the same ratios.

    Args:
        results_path: JSON file containing a list of result dicts. DPA ratio
            runs must carry ``"decision_ratio"``, ``"perplexity"`` and
            ``"flops_ratio"``; baselines need ``"perplexity"``.
        save_path: Output figure path; parent directories are created.
    """
    with open(results_path) as f:
        results = json.load(f)

    # NOTE(review): this filter also accepts model types such as
    # "dpa_fixed_r..." whereas plot_accuracy_vs_flops uses
    # startswith("dpa_r") — confirm which runs belong in this ablation.
    dpa_results = [
        r for r in results
        if "dpa" in r["model_type"] and "_r" in r["model_type"]
    ]
    if not dpa_results:
        print("No DPA ratio results found")
        return

    # Sort by ratio so the connecting lines run monotonically along x.
    dpa_results.sort(key=lambda r: r["decision_ratio"])
    ratios = [r["decision_ratio"] for r in dpa_results]
    ppls = [r["perplexity"] for r in dpa_results]
    flops = [r["flops_ratio"] for r in dpa_results]

    # Get baselines
    full_ppl = _baseline_perplexity(results, "full_transformer")
    linear_ppl = _baseline_perplexity(results, "pure_linear")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: perplexity vs ratio
    ax1.plot(ratios, ppls, "o-", color=COLORS["dpa"], linewidth=2, markersize=8, label="DPA")
    if full_ppl is not None:
        ax1.axhline(full_ppl, color=COLORS["full_transformer"], linestyle="--",
                    label=f"Full Transformer ({full_ppl:.1f})")
    if linear_ppl is not None:
        ax1.axhline(linear_ppl, color=COLORS["pure_linear"], linestyle="--",
                    label=f"Pure Linear ({linear_ppl:.1f})")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Perplexity ↓")
    ax1.set_title("(a) Quality vs Decision Point Ratio")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: FLOPs vs ratio
    ax2.plot(ratios, flops, "s-", color=COLORS["dpa"], linewidth=2, markersize=8)
    ax2.axhline(1.0, color=COLORS["full_transformer"], linestyle="--",
                label="Full Transformer (1.0x)")
    ax2.set_xlabel("Decision Point Ratio")
    ax2.set_ylabel("FLOPs (relative)")
    ax2.set_title("(b) Compute Cost vs Decision Point Ratio")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    _save_figure(fig, save_path)


def plot_trajectory_analysis(traj_path, save_path="figures/trajectory_analysis.pdf"):
    """Visualize decision points in agent trajectories.

    Left panel is a histogram of per-trajectory decision ratios with the mean
    marked; right panel shows, per step role, how many steps were decision
    points versus routine.

    Args:
        traj_path: JSON file containing a list of trajectory dicts, each with
            ``"decision_ratio"`` and a ``"steps"`` list whose items have
            ``"role"`` and ``"is_decision_point"`` keys.
        save_path: Output figure path; parent directories are created.
    """
    with open(traj_path) as f:
        trajectories = json.load(f)

    if not trajectories:
        # Avoid np.mean([]) -> NaN line and a "nan%" legend label.
        print("No trajectories found")
        return

    ratios = [t["decision_ratio"] for t in trajectories]
    mean_ratio = np.mean(ratios)

    # role -> {"dp": count, "routine": count}, in first-seen role order.
    step_types = {}
    for t in trajectories:
        for step in t["steps"]:
            counts = step_types.setdefault(step["role"], {"dp": 0, "routine": 0})
            counts["dp" if step["is_decision_point"] else "routine"] += 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: distribution of decision ratios
    ax1.hist(ratios, bins=30, color=COLORS["dpa"], alpha=0.7, edgecolor="white")
    ax1.axvline(mean_ratio, color="red", linestyle="--", label=f"Mean: {mean_ratio:.1%}")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Count")
    ax1.set_title("(a) Distribution of Decision Ratios")
    ax1.legend()

    # Right: decision points by step type (grouped bars, width 0.4 each)
    roles = list(step_types)
    dp_counts = [step_types[role]["dp"] for role in roles]
    routine_counts = [step_types[role]["routine"] for role in roles]
    x = np.arange(len(roles))
    ax2.bar(x - 0.2, dp_counts, 0.4, label="Decision Point", color=COLORS["dpa"])
    ax2.bar(x + 0.2, routine_counts, 0.4, label="Routine", color="#ccc")
    ax2.set_xticks(x)
    # Right-align rotated labels so each one ends under its tick.
    ax2.set_xticklabels(roles, rotation=30, ha="right")
    ax2.set_ylabel("Count")
    ax2.set_title("(b) Decision Points by Step Type")
    ax2.legend()

    _save_figure(fig, save_path)


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        plot_accuracy_vs_flops(sys.argv[1])
        plot_decision_ratio_ablation(sys.argv[1])
    else:
        sys.exit(f"usage: {sys.argv[0]} RESULTS_JSON")