| """Publication-quality figures for DPA paper.""" |
|
|
| import json |
| import matplotlib.pyplot as plt |
| import matplotlib |
| import numpy as np |
| from pathlib import Path |
|
|
# Global matplotlib defaults applied to every figure produced by this module:
# serif text, enlarged axis labels/titles, and a 150 dpi figure resolution.
matplotlib.rcParams.update(
    {
        # Base typography.
        "font.size": 12,
        "font.family": "serif",
        # Axis label and title sizes.
        "axes.labelsize": 14,
        "axes.titlesize": 15,
        # Tick label sizes.
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        # Legend text and figure resolution.
        "legend.fontsize": 10,
        "figure.dpi": 150,
    }
)
|
|
# Hex colour assigned to each model family; the plotting functions look up
# entries by the family name derived from a result's "model_type".
COLORS = {
    "full_transformer": "#2196F3",  # blue
    "pure_linear": "#FF9800",  # orange
    "uniform_hybrid": "#4CAF50",  # green
    "dpa": "#E91E63",  # pink
    "dpa_fixed": "#9C27B0",  # purple
}
|
|
|
|
def plot_accuracy_vs_flops(results_path, save_path="figures/accuracy_vs_flops.pdf"):
    """Plot the main quality-vs-compute figure and save it to ``save_path``.

    Each result is drawn as one scatter point at (``flops_ratio``,
    ``perplexity``). Results whose ``model_type`` starts with ``"dpa_r"`` are
    additionally joined by a dashed line, ordered by ``flops_ratio``.

    Note that the y-axis is perplexity (lower is better), not accuracy.

    Args:
        results_path: Path to a JSON file holding a list of result objects.
            Each object is read for the keys ``"model_type"``,
            ``"flops_ratio"`` and ``"perplexity"``.
        save_path: Output file for the figure. Missing parent directories
            are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    # Families that already own a legend entry; every family is labelled once.
    labelled = set()
    for r in results:
        name = r["model_type"]
        # "dpa_r0.25" -> "dpa"; names without an "_r" suffix are left unchanged.
        base = name.split("_r")[0]
        is_dpa = base == "dpa"

        if base in labelled:
            label = ""  # an empty label keeps the point out of the legend
        else:
            # Label with the family name so a ratio sweep does not show up
            # under the name of whichever member happened to come first.
            label = base
            labelled.add(base)

        ax.scatter(
            r["flops_ratio"],
            r["perplexity"],
            c=COLORS.get(base, "#666"),
            # "*" is matplotlib's star marker; DPA points are highlighted.
            marker="*" if is_dpa else "o",
            s=120 if is_dpa else 60,
            zorder=5,
            label=label,
        )

    # Connect the DPA ratio sweep, ordered by compute cost.
    dpa_results = sorted(
        (r for r in results if r["model_type"].startswith("dpa_r")),
        key=lambda r: r["flops_ratio"],
    )
    if dpa_results:
        xs = [r["flops_ratio"] for r in dpa_results]
        ys = [r["perplexity"] for r in dpa_results]
        ax.plot(xs, ys, c=COLORS["dpa"], linewidth=2, alpha=0.5, linestyle="--")

    ax.set_xlabel("FLOPs (relative to Full Transformer)")
    ax.set_ylabel("Perplexity ↓")
    ax.set_title("Decision Point Attention: Accuracy vs Compute")
    if labelled:
        # Skip the legend for an empty result list, where matplotlib would
        # otherwise warn that no labelled artists exist.
        ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
|
|
|
|
def plot_decision_ratio_ablation(results_path, save_path="figures/ratio_ablation.pdf"):
    """Plot the decision-point-ratio ablation and save it to ``save_path``.

    Produces two side-by-side panels for the DPA ratio sweep:
    (a) perplexity vs decision ratio, with horizontal reference lines for the
    ``"full_transformer"`` and ``"pure_linear"`` results when present, and
    (b) relative FLOPs vs decision ratio, with a reference line at 1.0.

    If no result's ``model_type`` contains both ``"dpa"`` and ``"_r"``, a
    message is printed and nothing is saved.

    Args:
        results_path: Path to a JSON file holding a list of result objects.
            Each object is read for ``"model_type"`` and ``"perplexity"``;
            sweep entries are also read for ``"decision_ratio"`` and
            ``"flops_ratio"``.
        save_path: Output file for the figure. Missing parent directories
            are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    dpa_results = [
        r for r in results
        if "dpa" in r["model_type"] and "_r" in r["model_type"]
    ]
    if not dpa_results:
        print("No DPA ratio results found")
        return

    # The curves are drawn as connected lines, so order the points by ratio;
    # otherwise an unsorted results file yields a zig-zagging line.
    dpa_results.sort(key=lambda r: r["decision_ratio"])
    ratios = [r["decision_ratio"] for r in dpa_results]
    ppls = [r["perplexity"] for r in dpa_results]
    flops = [r["flops_ratio"] for r in dpa_results]

    # Baseline perplexities; None when the baseline is absent from the file.
    full_ppl = next(
        (r["perplexity"] for r in results if r["model_type"] == "full_transformer"),
        None,
    )
    linear_ppl = next(
        (r["perplexity"] for r in results if r["model_type"] == "pure_linear"),
        None,
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Panel (a): quality.
    ax1.plot(ratios, ppls, "o-", color=COLORS["dpa"], linewidth=2, markersize=8, label="DPA")
    # Compare against None explicitly so a falsy numeric value is still drawn.
    if full_ppl is not None:
        ax1.axhline(
            full_ppl,
            color=COLORS["full_transformer"],
            linestyle="--",
            label=f"Full Transformer ({full_ppl:.1f})",
        )
    if linear_ppl is not None:
        ax1.axhline(
            linear_ppl,
            color=COLORS["pure_linear"],
            linestyle="--",
            label=f"Pure Linear ({linear_ppl:.1f})",
        )
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Perplexity ↓")
    ax1.set_title("(a) Quality vs Decision Point Ratio")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Panel (b): compute cost.
    ax2.plot(ratios, flops, "s-", color=COLORS["dpa"], linewidth=2, markersize=8)
    ax2.axhline(
        1.0,
        color=COLORS["full_transformer"],
        linestyle="--",
        label="Full Transformer (1.0x)",
    )
    ax2.set_xlabel("Decision Point Ratio")
    ax2.set_ylabel("FLOPs (relative)")
    ax2.set_title("(b) Compute Cost vs Decision Point Ratio")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Create the output directory so the function also works when called alone.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
|
|
|
|
def plot_trajectory_analysis(traj_path, save_path="figures/trajectory_analysis.pdf"):
    """Plot decision-point statistics of agent trajectories and save them.

    Produces two side-by-side panels:
    (a) a 30-bin histogram of per-trajectory ``decision_ratio`` with a
    vertical line at the mean, and
    (b) grouped bars counting, per step ``role``, how many steps are decision
    points versus routine steps.

    If the file contains no trajectories, a message is printed and nothing is
    saved.

    Args:
        traj_path: Path to a JSON file holding a list of trajectory objects.
            Each object is read for ``"decision_ratio"`` and ``"steps"``;
            each step is read for ``"role"`` and ``"is_decision_point"``.
        save_path: Output file for the figure. Missing parent directories
            are created.
    """
    with open(traj_path, encoding="utf-8") as f:
        trajectories = json.load(f)

    if not trajectories:
        # The mean of an empty list is NaN (with a RuntimeWarning) and the
        # resulting figure would be meaningless, so stop here.
        print("No trajectories found")
        return

    ratios = [t["decision_ratio"] for t in trajectories]

    # role -> {"dp": decision-point step count, "routine": other step count}
    step_types = {}
    for t in trajectories:
        for step in t["steps"]:
            counts = step_types.setdefault(step["role"], {"dp": 0, "routine": 0})
            counts["dp" if step["is_decision_point"] else "routine"] += 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Panel (a): distribution of per-trajectory decision ratios.
    mean_ratio = np.mean(ratios)
    ax1.hist(ratios, bins=30, color=COLORS["dpa"], alpha=0.7, edgecolor="white")
    ax1.axvline(mean_ratio, color="red", linestyle="--", label=f"Mean: {mean_ratio:.1%}")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Count")
    ax1.set_title("(a) Distribution of Decision Ratios")
    ax1.legend()

    # Panel (b): decision-point vs routine step counts, grouped by role.
    roles = list(step_types)
    dp_counts = [step_types[role]["dp"] for role in roles]
    routine_counts = [step_types[role]["routine"] for role in roles]

    x = np.arange(len(roles))
    # Two 0.4-wide bars per role, offset by ±0.2 so they sit side by side.
    ax2.bar(x - 0.2, dp_counts, 0.4, label="Decision Point", color=COLORS["dpa"])
    ax2.bar(x + 0.2, routine_counts, 0.4, label="Routine", color="#ccc")
    ax2.set_xticks(x)
    ax2.set_xticklabels(roles, rotation=30)
    ax2.set_ylabel("Count")
    ax2.set_title("(b) Decision Points by Step Type")
    ax2.legend()

    # Create the output directory so the function also works when called alone.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the first command-line argument is the results JSON
    # file fed to both results-based figures. With no argument nothing is
    # plotted.
    import sys

    cli_args = sys.argv[1:]
    if cli_args:
        results_file = cli_args[0]
        plot_accuracy_vs_flops(results_file)
        plot_decision_ratio_ablation(results_file)
|
|