File size: 5,882 Bytes
09dd617 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """Publication-quality figures for DPA paper."""
import json
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from pathlib import Path
# Publication styling applied globally to every figure drawn by this module:
# serif text, enlarged axis labels/titles, and a 150-dpi figure resolution.
matplotlib.rc("font", size=12, family="serif")
matplotlib.rc("axes", labelsize=14, titlesize=15)
matplotlib.rc("xtick", labelsize=11)
matplotlib.rc("ytick", labelsize=11)
matplotlib.rc("legend", fontsize=10)
matplotlib.rc("figure", dpi=150)
# Plot color per model family. Keys are the model_type with any "_r<ratio>"
# suffix removed (e.g. "dpa_r0.1" -> "dpa").
COLORS = dict(
    full_transformer="#2196F3",  # blue
    pure_linear="#FF9800",       # orange
    uniform_hybrid="#4CAF50",    # green
    dpa="#E91E63",               # pink
    dpa_fixed="#9C27B0",         # purple
)
def plot_accuracy_vs_flops(results_path, save_path="figures/accuracy_vs_flops.pdf"):
    """Main figure: perplexity vs. compute (FLOPs) trade-off.

    Draws one scatter point per result record, colored by model family.
    Records whose ``model_type`` starts with ``"dpa_r"`` are also joined by a
    dashed line ordered by FLOPs.

    Args:
        results_path: JSON file holding a list of dicts with at least the keys
            ``"model_type"``, ``"flops_ratio"`` and ``"perplexity"``.
        save_path: Where to write the figure; parent directories are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    labeled = set()  # model families that already have a legend entry
    for r in results:
        # Drop a ratio suffix such as "_r0.1" so every ratio of one model
        # family shares a color and a single legend entry.
        base = r["model_type"].split("_r")[0]
        is_dpa = base == "dpa"
        label = "_nolegend_" if base in labeled else base
        labeled.add(base)
        ax.scatter(
            r["flops_ratio"],
            r["perplexity"],
            c=COLORS.get(base, "#666"),
            s=120 if is_dpa else 60,
            # "*" is Matplotlib's star marker; a literal "★" is not a valid
            # marker spec (and the marker was previously never passed).
            marker="*" if is_dpa else "o",
            zorder=5,
            label=label,
        )

    # Trace the DPA ratio sweep in order of increasing compute.
    dpa_results = sorted(
        (r for r in results if r["model_type"].startswith("dpa_r")),
        key=lambda r: r["flops_ratio"],
    )
    if dpa_results:
        ax.plot(
            [r["flops_ratio"] for r in dpa_results],
            [r["perplexity"] for r in dpa_results],
            c=COLORS["dpa"],
            linewidth=2,
            alpha=0.5,
            linestyle="--",
        )

    ax.set_xlabel("FLOPs (relative to Full Transformer)")
    ax.set_ylabel("Perplexity ↓")
    ax.set_title("Decision Point Attention: Accuracy vs Compute")
    if labeled:  # legend() warns when no labeled artists exist
        ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    plt.close(fig)
def plot_decision_ratio_ablation(results_path, save_path="figures/ratio_ablation.pdf"):
    """Ablation figure: effect of the decision-point ratio on DPA.

    Left panel: DPA perplexity vs. decision-point ratio, with horizontal
    reference lines for the ``full_transformer`` and ``pure_linear`` baselines
    when those records are present. Right panel: relative FLOPs vs. ratio,
    with a 1.0x reference line.

    Args:
        results_path: JSON file holding a list of result dicts. DPA sweep
            records (``model_type`` starting with ``"dpa_r"``) must carry
            ``"decision_ratio"``, ``"perplexity"`` and ``"flops_ratio"``.
        save_path: Where to write the figure; parent directories are created.
    """
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)

    # Only ratio-swept DPA runs, so e.g. "dpa_fixed_r..." runs are not mixed
    # into the DPA curve. Sort by ratio so the connecting lines are monotone
    # in x rather than following file order.
    dpa_results = sorted(
        (r for r in results if r["model_type"].startswith("dpa_r")),
        key=lambda r: r["decision_ratio"],
    )
    if not dpa_results:
        print("No DPA ratio results found")
        return

    ratios = [r["decision_ratio"] for r in dpa_results]
    ppls = [r["perplexity"] for r in dpa_results]
    flops = [r["flops_ratio"] for r in dpa_results]

    # Baselines (None when the corresponding run is absent).
    full_ppl = next((r["perplexity"] for r in results if r["model_type"] == "full_transformer"), None)
    linear_ppl = next((r["perplexity"] for r in results if r["model_type"] == "pure_linear"), None)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: perplexity vs ratio
    ax1.plot(ratios, ppls, "o-", color=COLORS["dpa"], linewidth=2, markersize=8, label="DPA")
    if full_ppl is not None:
        ax1.axhline(full_ppl, color=COLORS["full_transformer"], linestyle="--", label=f"Full Transformer ({full_ppl:.1f})")
    if linear_ppl is not None:
        ax1.axhline(linear_ppl, color=COLORS["pure_linear"], linestyle="--", label=f"Pure Linear ({linear_ppl:.1f})")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Perplexity ↓")
    ax1.set_title("(a) Quality vs Decision Point Ratio")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: FLOPs vs ratio
    ax2.plot(ratios, flops, "s-", color=COLORS["dpa"], linewidth=2, markersize=8)
    ax2.axhline(1.0, color=COLORS["full_transformer"], linestyle="--", label="Full Transformer (1.0x)")
    ax2.set_xlabel("Decision Point Ratio")
    ax2.set_ylabel("FLOPs (relative)")
    ax2.set_title("(b) Compute Cost vs Decision Point Ratio")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # savefig does not create missing directories.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    plt.close(fig)
def plot_trajectory_analysis(traj_path, save_path="figures/trajectory_analysis.pdf"):
    """Visualize decision points in agent trajectories.

    Left panel: histogram of per-trajectory decision ratios with a dashed
    line at the mean. Right panel: per step role, the number of steps that
    are decision points vs. routine steps.

    Args:
        traj_path: JSON file holding a list of trajectory dicts, each with a
            ``"decision_ratio"`` and a ``"steps"`` list whose items carry
            ``"role"`` and ``"is_decision_point"``.
        save_path: Where to write the figure; parent directories are created.
    """
    with open(traj_path, encoding="utf-8") as f:
        trajectories = json.load(f)
    if not trajectories:
        # np.mean([]) would yield nan (with a RuntimeWarning) and an empty plot.
        print("No trajectories found")
        return

    ratios = [t["decision_ratio"] for t in trajectories]
    mean_ratio = np.mean(ratios)

    # role -> {"dp": n, "routine": n}; dict order = first appearance of a role.
    step_types = {}
    for t in trajectories:
        for step in t["steps"]:
            counts = step_types.setdefault(step["role"], {"dp": 0, "routine": 0})
            counts["dp" if step["is_decision_point"] else "routine"] += 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: distribution of decision ratios
    ax1.hist(ratios, bins=30, color=COLORS["dpa"], alpha=0.7, edgecolor="white")
    ax1.axvline(mean_ratio, color="red", linestyle="--", label=f"Mean: {mean_ratio:.1%}")
    ax1.set_xlabel("Decision Point Ratio")
    ax1.set_ylabel("Count")
    ax1.set_title("(a) Distribution of Decision Ratios")
    ax1.legend()

    # Right: decision points by step type (grouped bars, width 0.4 each)
    roles = list(step_types)
    dp_counts = [step_types[role]["dp"] for role in roles]
    routine_counts = [step_types[role]["routine"] for role in roles]
    x = np.arange(len(roles))
    ax2.bar(x - 0.2, dp_counts, 0.4, label="Decision Point", color=COLORS["dpa"])
    ax2.bar(x + 0.2, routine_counts, 0.4, label="Routine", color="#ccc")
    ax2.set_xticks(x)
    ax2.set_xticklabels(roles, rotation=30)
    ax2.set_ylabel("Count")
    ax2.set_title("(b) Decision Points by Step Type")
    ax2.legend()

    # savefig does not create missing directories.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved {save_path}")
    plt.close(fig)
if __name__ == "__main__":
    # CLI: RESULTS_JSON drives the accuracy/FLOPs and ratio-ablation figures;
    # an optional TRAJECTORIES_JSON also produces the trajectory figure.
    import sys

    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"Usage: {sys.argv[0]} RESULTS_JSON [TRAJECTORIES_JSON]")
    plot_accuracy_vs_flops(sys.argv[1])
    plot_decision_ratio_ablation(sys.argv[1])
    if len(sys.argv) > 2:
        plot_trajectory_analysis(sys.argv[2])
|