laguna-martini / scripts /plot_pruning_insights.py
nikgeo's picture
Compare importance-ranked and random pruning quality curves
9c0900a verified
Raw
History Blame Contribute Delete
6.77 kB
#!/usr/bin/env python3
"""Plot Laguna Martini pruning insights from published artifacts."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def parse_args() -> argparse.Namespace:
    """Parse command-line options from ``sys.argv``.

    Every option is converted to a ``pathlib.Path`` and has a default that
    is relative to the current working directory:

    * ``--summary-path``: JSON summary read by ``json.loads``.
    * ``--mask-path``: ``.npy`` array loaded with ``np.load``.
    * ``--scores-path``: ``.npy`` array of atomic importance scores.
    * ``--output-dir``: directory that receives the rendered PNG files.
    """
    cli = argparse.ArgumentParser()
    path_options = (
        ("--summary-path", "results/summary.json"),
        ("--mask-path", "artifacts/group_keep_mask_25pct.npy"),
        ("--scores-path", "artifacts/atomic_scores.npy"),
        ("--output-dir", "assets"),
    )
    for flag, default in path_options:
        cli.add_argument(flag, type=Path, default=Path(default))
    return cli.parse_args()
def configure_style() -> None:
    """Apply the dark theme shared by every figure this script renders.

    Seaborn's ``whitegrid`` style at ``talk`` context is installed first.
    A few matplotlib rcParams are then overridden to give dark blue figure
    and axes backgrounds with pale cyan text, labels, ticks and edges.
    """
    sns.set_theme(style="whitegrid", context="talk")
    text_color = "#e6fbff"
    tick_color = "#cceff4"
    overrides = {
        "figure.facecolor": "#081522",
        "axes.facecolor": "#102235",
        "axes.edgecolor": "#7ed9e6",
        "grid.color": "#34536a",
    }
    # Titles, axis labels and free text all share one foreground color.
    for key in ("axes.labelcolor", "axes.titlecolor", "text.color"):
        overrides[key] = text_color
    for key in ("xtick.color", "ytick.color"):
        overrides[key] = tick_color
    plt.rcParams.update(overrides)
def plot_pruning_curve(summary: dict, output_dir: Path) -> None:
    """Plot perplexity increase for importance-ranked vs. random pruning.

    Keys read from ``summary``:

    * ``native_group_loss_sweep``: list of rows, each with
      ``groups_pruned_ratio`` (multiplied by 100 for the x-axis) and
      ``perplexity``.
    * ``native_group_loss_sweep_baseline["perplexity"]``: reference
      perplexity that every point is compared against.
    * ``random_native_group_10pct_control`` and
      ``random_native_group_25pct_control``: single rows of the same form
      for the random control. A (0, 0) point is prepended to that curve.

    The sweep point nearest 10% is highlighted. Its annotation shows the
    perplexity increase computed from the data. The figure is saved as
    ``pruning_quality_curve.png`` in ``output_dir``.

    Raises:
        ValueError: if no sweep point lies within 0.5 percentage points
            of 10%.
    """
    baseline = summary["native_group_loss_sweep_baseline"]["perplexity"]

    def to_percentages(rows):
        """Return (percent pruned, percent perplexity increase) lists."""
        pruned = [row["groups_pruned_ratio"] * 100 for row in rows]
        increase = [(row["perplexity"] / baseline - 1) * 100 for row in rows]
        return pruned, increase

    def fit_limits(lo, hi, values, pad):
        """Return (lo, hi), pushing an edge out by ``pad`` past any clipped value."""
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return lo, hi
        v_min, v_max = float(finite.min()), float(finite.max())
        return (lo if v_min >= lo else v_min - pad, hi if v_max <= hi else v_max + pad)

    sweep_x, sweep_y = to_percentages(summary["native_group_loss_sweep"])
    x = np.array(sweep_x)
    y = np.array(sweep_y)
    control_x, control_y = to_percentages(
        [
            summary["random_native_group_10pct_control"],
            summary["random_native_group_25pct_control"],
        ]
    )
    random_x = np.array([0, *control_x])
    random_y = np.array([0, *control_y])

    fig, ax = plt.subplots(figsize=(10, 5.6))
    sns.lineplot(
        x=x,
        y=y,
        marker="o",
        linewidth=3,
        markersize=10,
        color="#78dce8",
        label="HEAPr-style importance ranking",
        ax=ax,
    )
    sns.lineplot(
        x=random_x,
        y=random_y,
        marker="o",
        linewidth=2.5,
        markersize=9,
        linestyle="--",
        color="#f5a65b",
        label="Random control",
        ax=ax,
    )

    # A ratio such as pruned_groups / total_groups rarely multiplies out to
    # exactly 10.0, so pick the nearest sweep point instead of testing
    # float equality.
    target_pct = 10.0
    selected = int(np.argmin(np.abs(x - target_pct)))
    if abs(x[selected] - target_pct) > 0.5:
        plt.close(fig)
        raise ValueError(
            f"no pruning sweep point near {target_pct:g}%; closest is {x[selected]:.2f}%"
        )
    ax.scatter([x[selected]], [y[selected]], s=180, color="#b7d36b", zorder=4)
    ax.annotate(
        # Derived from the plotted point so the label cannot drift from the data.
        f"Recommended operating point\n{x[selected]:.0f}% pruning, "
        f"{y[selected]:+.2f}% perplexity",
        xy=(x[selected], y[selected]),
        xytext=(48, 34),
        textcoords="offset points",
        color="#e6fbff",
        fontsize=12,
        arrowprops={"arrowstyle": "->", "color": "#b7d36b"},
    )
    ax.set(
        title="Importance Ranking Outperforms Random Pruning",
        xlabel="Structured MoE parameters identified for removal (%)",
        ylabel="Perplexity increase vs. BF16 baseline (%)",
    )
    # Keep the published framing, but widen an edge rather than clip a point.
    ax.set_xlim(*fit_limits(-1.0, 42.0, np.concatenate([x, random_x]), 2.0))
    ax.set_ylim(*fit_limits(-1.0, 44.0, np.concatenate([y, random_y]), 2.0))
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "pruning_quality_curve.png", dpi=180)
    plt.close(fig)
def plot_atomic_importance_by_layer(scores: np.ndarray, output_dir: Path) -> None:
    """Box-plot the per-layer distribution of log10 atomic importance scores.

    ``scores`` is indexed as ``scores[layer, ...]``. Every axis after the
    first is flattened, giving one distribution (one box) per layer, and a
    line traces each layer's median. The array must have at least 2
    dimensions.

    log10 is undefined for non-positive values. Such scores are clamped up
    to the smallest positive score, or to the dtype's smallest normal value
    when no score is positive. Integer arrays are promoted to float64 first.
    The y-axis spans the 0.5%-99.5% quantiles of the log scores. The figure
    is saved as ``atomic_importance_by_layer.png`` in ``output_dir``.
    """
    if not np.issubdtype(scores.dtype, np.floating):
        # np.finfo rejects integer dtypes; promote so a floor can be defined.
        scores = scores.astype(np.float64)
    tiny = np.finfo(scores.dtype).tiny
    positive_scores = scores[scores > 0]
    # min() on an empty array raises, so fall back to ``tiny`` when nothing
    # is positive.
    if positive_scores.size:
        score_floor = max(float(positive_scores.min()), tiny)
    else:
        score_floor = float(tiny)
    log_scores = np.log10(np.clip(scores, score_floor, None))

    # Flatten every non-layer axis once. The boxes and the median line then
    # describe exactly the same per-layer samples, whatever the input rank.
    per_layer = log_scores.reshape(log_scores.shape[0], -1)
    layers = np.arange(per_layer.shape[0])
    medians = np.median(per_layer, axis=1)

    fig, ax = plt.subplots(figsize=(12, 5.8))
    sns.boxplot(
        data=list(per_layer),
        color="#78dce8",
        fliersize=0,
        linewidth=0.8,
        width=0.72,
        ax=ax,
    )
    sns.lineplot(
        x=layers,
        y=medians,
        color="#b7d36b",
        linewidth=2.5,
        marker="o",
        markersize=4,
        label="Layer median",
        ax=ax,
    )
    ax.set(
        title="Later Layers Have Lower-Importance Atomic Experts",
        xlabel="Sparse layer index",
        ylabel="log10(atomic importance score)",
    )
    # Crop the extreme tails so the bulk of each distribution stays readable.
    ax.set_ylim(np.quantile(log_scores, 0.005), np.quantile(log_scores, 0.995))
    ax.tick_params(axis="x", labelsize=9)
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "atomic_importance_by_layer.png", dpi=180)
    plt.close(fig)
def plot_gsm8k_curve(summary: dict, output_dir: Path) -> None:
    """Plot GSM8K-CoT exact-match scores against pruning level.

    Reads ``summary["evaluation"]["gsm8k_cot_8shot_first_10pct"]``. That
    mapping must contain ``<prefix>_strict_exact_match`` and
    ``<prefix>_flexible_exact_match`` for the prefix ``baseline`` (plotted
    at 0%) and for ``native_group_10pct``, ``native_group_20pct`` and
    ``native_group_25pct``.

    Both metrics are drawn as lines. The default y-window (0.62, 0.98) is
    widened only if a score would otherwise be clipped. The figure is saved
    as ``gsm8k_pruning_curve.png`` in ``output_dir``.
    """
    gsm8k = summary["evaluation"]["gsm8k_cot_8shot_first_10pct"]
    pruned_levels = (10, 20, 25)
    prefixes = ["baseline", *(f"native_group_{pct}pct" for pct in pruned_levels)]
    x = np.array([0, *pruned_levels])
    strict = np.array([gsm8k[f"{prefix}_strict_exact_match"] for prefix in prefixes])
    flexible = np.array([gsm8k[f"{prefix}_flexible_exact_match"] for prefix in prefixes])

    def fit_limits(lo, hi, values, pad):
        """Return (lo, hi), pushing an edge out by ``pad`` past any clipped value."""
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return lo, hi
        v_min, v_max = float(finite.min()), float(finite.max())
        return (lo if v_min >= lo else v_min - pad, hi if v_max <= hi else v_max + pad)

    fig, ax = plt.subplots(figsize=(10, 5.6))
    sns.lineplot(
        x=x,
        y=strict,
        marker="o",
        linewidth=3,
        markersize=10,
        color="#78dce8",
        label="Strict exact match",
        ax=ax,
    )
    sns.lineplot(
        x=x,
        y=flexible,
        marker="o",
        linewidth=3,
        markersize=10,
        color="#b7d36b",
        label="Flexible exact match",
        ax=ax,
    )
    ax.set(
        title="GSM8K-CoT Is Increasingly Sensitive to Pruning",
        xlabel="Structured MoE parameters identified for removal (%)",
        ylabel="Exact match on paired first-10% subset",
    )
    ax.set_xticks(x)
    # A fixed window would silently hide scores that fall outside it.
    ax.set_ylim(*fit_limits(0.62, 0.98, np.concatenate([strict, flexible]), 0.02))
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "gsm8k_pruning_curve.png", dpi=180)
    plt.close(fig)
def main() -> None:
    """Load the artifacts, validate their shapes, and render all figures.

    All inputs are read and checked before the output directory is created,
    so a bad invocation leaves no empty directory behind.

    Raises:
        ValueError: if the mask or score array does not have the shape this
            script expects.
    """
    args = parse_args()
    # JSON text is UTF-8 by specification; don't rely on the locale encoding.
    summary = json.loads(args.summary_path.read_text(encoding="utf-8"))
    # The mask is never plotted; it is loaded only so its shape can be checked.
    mask = np.load(args.mask_path)
    scores = np.load(args.scores_path)
    if mask.shape != (39, 256, 8):
        raise ValueError(f"unexpected mask shape: {mask.shape}")
    if scores.shape != (39, 256, 512):
        raise ValueError(f"unexpected scores shape: {scores.shape}")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    configure_style()
    plot_pruning_curve(summary, args.output_dir)
    plot_atomic_importance_by_layer(scores, args.output_dir)
    plot_gsm8k_curve(summary, args.output_dir)
if __name__ == "__main__":
    # Render the figures only when run as a script, never on import.
    main()