# jasonfan's picture
# Upload folder using huggingface_hub
# 09dd617 verified
"""Publication-quality figures for DPA paper."""
import json
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from pathlib import Path
# Module-wide matplotlib defaults; applied once at import time, so every
# figure created by the functions below picks them up.
matplotlib.rcParams.update(
    {
        # Base typeface.
        "font.size": 12,
        "font.family": "serif",
        # Axis labels and titles.
        "axes.labelsize": 14,
        "axes.titlesize": 15,
        # Tick labels on both axes.
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        # Legend text and default figure resolution.
        "legend.fontsize": 10,
        "figure.dpi": 150,
    }
)
# Hex colour assigned to each model family. Keys are the ``model_type``
# values found in the results files with any ``_r<ratio>`` suffix removed.
COLORS = {
    "full_transformer": "#2196F3",  # blue
    "pure_linear": "#FF9800",  # orange
    "uniform_hybrid": "#4CAF50",  # green
    "dpa": "#E91E63",  # pink
    "dpa_fixed": "#9C27B0",  # purple
}
def plot_accuracy_vs_flops(results_path, save_path="figures/accuracy_vs_flops.pdf"):
    """Main figure: perplexity of every model against its relative FLOPs.

    Each entry of the results list becomes one scatter point, coloured by
    its model family (``model_type`` with any ``_r...`` suffix removed;
    unknown families fall back to grey). Points of the ``dpa`` family are
    drawn as larger stars, and all ``dpa_r*`` entries are joined by a dashed
    line ordered by ``flops_ratio``.

    Args:
        results_path: Path to a JSON file containing a list of objects with
            the keys ``model_type``, ``flops_ratio`` and ``perplexity``.
        save_path: Where the figure is written. Missing parent directories
            are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    # Families that already own a legend entry, so each family is listed
    # once (under the name of its first result) without rescanning the list.
    labelled = set()
    for r in results:
        name = r["model_type"]
        base = name.split("_r")[0]  # "dpa_r0.25" -> "dpa"; others unchanged
        is_dpa = base == "dpa"
        if base in labelled:
            label = ""  # empty labels are left out of the legend
        else:
            label = name
            labelled.add(base)
        ax.scatter(
            r["flops_ratio"],
            r["perplexity"],
            c=COLORS.get(base, "#666"),
            # "*" is matplotlib's star marker; it makes DPA points stand out.
            marker="*" if is_dpa else "o",
            s=120 if is_dpa else 60,
            zorder=5,
            label=label,
        )

    # Connect DPA points, left to right along the FLOPs axis.
    dpa_results = sorted(
        (r for r in results if r["model_type"].startswith("dpa_r")),
        key=lambda r: r["flops_ratio"],
    )
    if dpa_results:
        xs = [r["flops_ratio"] for r in dpa_results]
        ys = [r["perplexity"] for r in dpa_results]
        ax.plot(xs, ys, c=COLORS["dpa"], linewidth=2, alpha=0.5, linestyle="--")

    ax.set_xlabel("FLOPs (relative to Full Transformer)")
    ax.set_ylabel("Perplexity ↓")
    ax.set_title("Decision Point Attention: Accuracy vs Compute")
    if labelled:
        # Skip the legend (and matplotlib's "no handles" warning) when the
        # results list is empty.
        ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
def plot_decision_ratio_ablation(results_path, save_path="figures/ratio_ablation.pdf"):
    """Ablation: effect of the decision point ratio on quality and compute.

    Draws two panels from the DPA ratio runs in the results file:
    (a) perplexity against decision ratio, with horizontal reference lines
    for the ``full_transformer`` and ``pure_linear`` baselines when they are
    present, and (b) relative FLOPs against decision ratio, with a reference
    line at 1.0. Prints a message and returns without writing anything when
    the file contains no DPA ratio runs.

    Args:
        results_path: Path to a JSON file containing a list of objects with
            ``model_type``, ``perplexity`` and ``flops_ratio`` keys; DPA
            ratio runs must also carry ``decision_ratio``.
        save_path: Where the figure is written. Missing parent directories
            are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    # Sort by ratio so the connected line runs left to right instead of
    # following the order the runs happen to appear in the file.
    dpa_results = sorted(
        (r for r in results if "dpa" in r["model_type"] and "_r" in r["model_type"]),
        key=lambda r: r["decision_ratio"],
    )
    if not dpa_results:
        print("No DPA ratio results found")
        return

    ratios = [r["decision_ratio"] for r in dpa_results]
    ppls = [r["perplexity"] for r in dpa_results]
    flops = [r["flops_ratio"] for r in dpa_results]

    # Get baselines (None when the corresponding run is absent).
    full_ppl = next(
        (r["perplexity"] for r in results if r["model_type"] == "full_transformer"),
        None,
    )
    linear_ppl = next(
        (r["perplexity"] for r in results if r["model_type"] == "pure_linear"),
        None,
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: perplexity vs ratio
    ax1.plot(ratios, ppls, "o-", color=COLORS["dpa"], linewidth=2, markersize=8, label="DPA")
    if full_ppl is not None:
        ax1.axhline(
            full_ppl,
            color=COLORS["full_transformer"],
            linestyle="--",
            label=f"Full Transformer ({full_ppl:.1f})",
        )
    if linear_ppl is not None:
        ax1.axhline(
            linear_ppl,
            color=COLORS["pure_linear"],
            linestyle="--",
            label=f"Pure Linear ({linear_ppl:.1f})",
        )
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Perplexity ↓")
    ax1.set_title("(a) Quality vs Decision Point Ratio")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: FLOPs vs ratio
    ax2.plot(ratios, flops, "s-", color=COLORS["dpa"], linewidth=2, markersize=8)
    ax2.axhline(
        1.0,
        color=COLORS["full_transformer"],
        linestyle="--",
        label="Full Transformer (1.0x)",
    )
    ax2.set_xlabel("Decision Point Ratio")
    ax2.set_ylabel("FLOPs (relative)")
    ax2.set_title("(b) Compute Cost vs Decision Point Ratio")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Create the output directory so the function also works when called
    # on its own, matching plot_accuracy_vs_flops.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
def plot_trajectory_analysis(traj_path, save_path="figures/trajectory_analysis.pdf"):
    """Visualize decision points in agent trajectories.

    Draws two panels: (a) a histogram of the per-trajectory
    ``decision_ratio`` with a vertical line at the mean, and (b) grouped
    bars giving, for each step ``role``, how many steps are flagged as
    decision points and how many are not. Prints a message and returns
    without writing anything when the file contains no trajectories.

    Args:
        traj_path: Path to a JSON file containing a list of objects with a
            ``decision_ratio`` value and a ``steps`` list; each step has a
            ``role`` and an ``is_decision_point`` flag.
        save_path: Where the figure is written. Missing parent directories
            are created.
    """
    with open(traj_path, encoding="utf-8") as f:
        trajectories = json.load(f)

    if not trajectories:
        # The mean of an empty list is NaN, so there is nothing to plot.
        print("No trajectories found")
        return

    ratios = [t["decision_ratio"] for t in trajectories]

    # role -> {"dp": decision-point steps, "routine": all other steps}
    step_types = {}
    for t in trajectories:
        for step in t["steps"]:
            counts = step_types.setdefault(step["role"], {"dp": 0, "routine": 0})
            counts["dp" if step["is_decision_point"] else "routine"] += 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: distribution of decision ratios
    mean_ratio = np.mean(ratios)
    ax1.hist(ratios, bins=30, color=COLORS["dpa"], alpha=0.7, edgecolor="white")
    ax1.axvline(mean_ratio, color="red", linestyle="--", label=f"Mean: {mean_ratio:.1%}")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Count")
    ax1.set_title("(a) Distribution of Decision Ratios")
    ax1.legend()

    # Right: decision points by step type (roles in first-seen order)
    roles = list(step_types)
    dp_counts = [step_types[r]["dp"] for r in roles]
    routine_counts = [step_types[r]["routine"] for r in roles]
    x = np.arange(len(roles))
    # Two 0.4-wide bars per role, centred on the role's tick.
    ax2.bar(x - 0.2, dp_counts, 0.4, label="Decision Point", color=COLORS["dpa"])
    ax2.bar(x + 0.2, routine_counts, 0.4, label="Routine", color="#ccc")
    ax2.set_xticks(x)
    ax2.set_xticklabels(roles, rotation=30)
    ax2.set_ylabel("Count")
    ax2.set_title("(b) Decision Points by Step Type")
    ax2.legend()

    # Create the output directory so the default save path works even when
    # "figures/" does not exist yet.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
if __name__ == "__main__":
    # Command line: RESULTS_JSON [TRAJECTORIES_JSON]
    #   RESULTS_JSON       drives the accuracy-vs-FLOPs and ratio-ablation figures.
    #   TRAJECTORIES_JSON  optional; when given, the trajectory figure is drawn too.
    import sys

    if len(sys.argv) < 2:
        # Fail loudly instead of silently producing no figures.
        sys.exit(f"usage: {sys.argv[0]} RESULTS_JSON [TRAJECTORIES_JSON]")

    plot_accuracy_vs_flops(sys.argv[1])
    plot_decision_ratio_ablation(sys.argv[1])
    if len(sys.argv) > 2:
        plot_trajectory_analysis(sys.argv[2])