| |
| """Plot Laguna Martini pruning insights from published artifacts.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import seaborn as sns |
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the plotting script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is what the previous
            zero-argument form of this function always did. Passing an
            explicit list makes the parser usable from tests and other code.

    Returns:
        A namespace with ``summary_path``, ``mask_path``, ``scores_path`` and
        ``output_dir`` attributes, each a ``pathlib.Path``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--summary-path",
        type=Path,
        default=Path("results/summary.json"),
        help="JSON summary file with the evaluation results to plot.",
    )
    parser.add_argument(
        "--mask-path",
        type=Path,
        default=Path("artifacts/group_keep_mask_25pct.npy"),
        help="NumPy .npy file containing the group keep mask.",
    )
    parser.add_argument(
        "--scores-path",
        type=Path,
        default=Path("artifacts/atomic_scores.npy"),
        help="NumPy .npy file containing the atomic importance scores.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("assets"),
        help="Directory the PNG figures are written to.",
    )
    return parser.parse_args(argv)
|
|
|
|
def configure_style() -> None:
    """Install the dark plotting theme shared by all figures in this script.

    The seaborn ``whitegrid``/``talk`` theme is applied first; the matplotlib
    rcParams overrides are applied afterwards so they replace the matching
    theme values.
    """
    sns.set_theme(style="whitegrid", context="talk")

    # Colours that are reused by several rcParams keys.
    light_text = "#e6fbff"
    tick_text = "#cceff4"

    overrides = {
        "figure.facecolor": "#081522",
        "axes.facecolor": "#102235",
        "axes.edgecolor": "#7ed9e6",
        "grid.color": "#34536a",
    }
    overrides.update(
        dict.fromkeys(("axes.labelcolor", "axes.titlecolor", "text.color"), light_text)
    )
    overrides.update(dict.fromkeys(("xtick.color", "ytick.color"), tick_text))

    plt.rcParams.update(overrides)
|
|
|
|
def plot_pruning_curve(summary: dict, output_dir: Path) -> None:
    """Plot perplexity increase against pruning ratio and save it as a PNG.

    Two curves are drawn: the importance-ranked sweep from
    ``summary["native_group_loss_sweep"]`` and the two random-pruning controls
    (``random_native_group_10pct_control`` / ``..._25pct_control``), the latter
    anchored at the origin. Both are expressed as a percentage increase over
    ``summary["native_group_loss_sweep_baseline"]["perplexity"]``.

    The 10% point of the sweep, if present, is highlighted and annotated with
    the perplexity increase actually plotted at that point.

    Args:
        summary: Parsed summary JSON. Every row used here must provide
            ``groups_pruned_ratio`` (a fraction) and ``perplexity``.
        output_dir: Existing directory that receives
            ``pruning_quality_curve.png``.

    Raises:
        KeyError: If a required key is missing from ``summary``.
    """
    rows = summary["native_group_loss_sweep"]
    baseline = summary["native_group_loss_sweep_baseline"]["perplexity"]
    x = np.array([row["groups_pruned_ratio"] * 100 for row in rows])
    y = np.array([(row["perplexity"] / baseline - 1) * 100 for row in rows])
    random_controls = [
        summary["random_native_group_10pct_control"],
        summary["random_native_group_25pct_control"],
    ]
    # The leading 0 anchors the random-control curve at the unpruned baseline.
    random_x = np.array([0, *[row["groups_pruned_ratio"] * 100 for row in random_controls]])
    random_y = np.array([0, *[(row["perplexity"] / baseline - 1) * 100 for row in random_controls]])

    fig, ax = plt.subplots(figsize=(10, 5.6))
    sns.lineplot(
        x=x,
        y=y,
        marker="o",
        linewidth=3,
        markersize=10,
        color="#78dce8",
        label="HEAPr-style importance ranking",
        ax=ax,
    )
    sns.lineplot(
        x=random_x,
        y=random_y,
        marker="o",
        linewidth=2.5,
        markersize=9,
        linestyle="--",
        color="#f5a65b",
        label="Random control",
        ax=ax,
    )

    # ``ratio * 100`` is floating-point arithmetic (e.g. 0.07 * 100 is not
    # exactly 7.0), so locate the operating point with a tolerance rather than
    # exact equality.
    recommended_pct = 10.0
    matches = np.flatnonzero(np.isclose(x, recommended_pct))
    if matches.size:
        selected = int(matches[0])
        ax.scatter([x[selected]], [y[selected]], s=180, color="#b7d36b", zorder=4)
        ax.annotate(
            # Derive the label from the plotted point so the text cannot drift
            # away from the data when the summary is regenerated.
            "Recommended operating point\n"
            f"{x[selected]:.0f}% pruning, {y[selected]:+.2f}% perplexity",
            xy=(x[selected], y[selected]),
            xytext=(48, 34),
            textcoords="offset points",
            color="#e6fbff",
            fontsize=12,
            arrowprops={"arrowstyle": "->", "color": "#b7d36b"},
        )
    # When the sweep has no 10% row the curve is still drawn, just without the
    # highlighted operating point.

    ax.set(
        title="Importance Ranking Outperforms Random Pruning",
        xlabel="Structured MoE parameters identified for removal (%)",
        ylabel="Perplexity increase vs. BF16 baseline (%)",
    )
    # Fixed axis limits: points outside these ranges are not visible.
    ax.set_xlim(-1, 42)
    ax.set_ylim(-1, 44)
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "pruning_quality_curve.png", dpi=180)
    plt.close(fig)
|
|
|
|
def plot_atomic_importance_by_layer(scores: np.ndarray, output_dir: Path) -> None:
    """Draw per-layer box plots of log10 importance scores and save a PNG.

    The first axis of ``scores`` is treated as the layer index; all remaining
    axes are flattened into one distribution per layer. Non-positive scores are
    clipped to the smallest positive score (but never below the dtype's
    smallest normal number) so that the logarithm is defined everywhere. A line
    through the per-layer medians is overlaid on the boxes.

    Args:
        scores: Array of importance scores with the layer on axis 0.
            Non-floating dtypes are converted to ``float64``.
        output_dir: Existing directory that receives
            ``atomic_importance_by_layer.png``.

    Raises:
        ValueError: If ``scores`` contains no strictly positive value, since
            no finite log-scale floor can be derived in that case.
    """
    scores = np.asarray(scores)
    if not np.issubdtype(scores.dtype, np.floating):
        # np.finfo below only accepts floating dtypes.
        scores = scores.astype(np.float64)

    positive_scores = scores[scores > 0]
    if positive_scores.size == 0:
        raise ValueError(
            "scores must contain at least one positive value to be plotted on a log scale"
        )
    score_floor = max(float(positive_scores.min()), np.finfo(scores.dtype).tiny)
    log_scores = np.log10(np.clip(scores, score_floor, None))

    # One flat distribution per layer, independent of how many trailing axes
    # the input has.
    per_layer = log_scores.reshape(log_scores.shape[0], -1)
    layers = np.arange(per_layer.shape[0])
    medians = np.median(per_layer, axis=1)

    fig, ax = plt.subplots(figsize=(12, 5.8))
    sns.boxplot(
        data=list(per_layer),
        color="#78dce8",
        fliersize=0,  # outlier markers are suppressed; see the y-limits below
        linewidth=0.8,
        width=0.72,
        ax=ax,
    )
    # The list-of-arrays box plot places box i at x == i, so the integer layer
    # indices line up with the boxes.
    sns.lineplot(
        x=layers,
        y=medians,
        color="#b7d36b",
        linewidth=2.5,
        marker="o",
        markersize=4,
        label="Layer median",
        ax=ax,
    )
    # NOTE: the title states a conclusion; it is not derived from ``scores``.
    ax.set(
        title="Later Layers Have Lower-Importance Atomic Experts",
        xlabel="Sparse layer index",
        ylabel="log10(atomic importance score)",
    )
    # Show the central 99% of all values so extreme scores do not flatten the boxes.
    ax.set_ylim(np.quantile(log_scores, 0.005), np.quantile(log_scores, 0.995))
    ax.tick_params(axis="x", labelsize=9)
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "atomic_importance_by_layer.png", dpi=180)
    plt.close(fig)
|
|
|
|
def plot_gsm8k_curve(summary: dict, output_dir: Path) -> None:
    """Plot GSM8K exact-match scores at each pruning level and save a PNG.

    Reads the strict and flexible exact-match values for the baseline and the
    10%, 20% and 25% pruned variants from
    ``summary["evaluation"]["gsm8k_cot_8shot_first_10pct"]`` and writes
    ``gsm8k_pruning_curve.png`` into ``output_dir``.

    Raises:
        KeyError: If any of the expected result keys is missing.
    """
    results = summary["evaluation"]["gsm8k_cot_8shot_first_10pct"]
    pruned_pct = np.array([0, 10, 20, 25])

    # Result keys listed in the same order as the pruning levels above.
    strict_keys = (
        "baseline_strict_exact_match",
        "native_group_10pct_strict_exact_match",
        "native_group_20pct_strict_exact_match",
        "native_group_25pct_strict_exact_match",
    )
    flexible_keys = (
        "baseline_flexible_exact_match",
        "native_group_10pct_flexible_exact_match",
        "native_group_20pct_flexible_exact_match",
        "native_group_25pct_flexible_exact_match",
    )
    # Resolve every value before a figure exists, so a missing key fails
    # without leaving a figure open.
    strict_scores = np.array([results[key] for key in strict_keys])
    flexible_scores = np.array([results[key] for key in flexible_keys])

    series = (
        (strict_scores, "#78dce8", "Strict exact match"),
        (flexible_scores, "#b7d36b", "Flexible exact match"),
    )

    fig, ax = plt.subplots(figsize=(10, 5.6))
    for values, line_color, legend_label in series:
        sns.lineplot(
            x=pruned_pct,
            y=values,
            marker="o",
            linewidth=3,
            markersize=10,
            color=line_color,
            label=legend_label,
            ax=ax,
        )

    ax.set_title("GSM8K-CoT Is Increasingly Sensitive to Pruning")
    ax.set_xlabel("Structured MoE parameters identified for removal (%)")
    ax.set_ylabel("Exact match on paired first-10% subset")
    ax.set_xticks(pruned_pct)
    # Fixed y-range: scores outside it are not visible.
    ax.set_ylim(0.62, 0.98)
    ax.legend(frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "gsm8k_pruning_curve.png", dpi=180)
    plt.close(fig)
|
|
|
|
def main() -> None:
    """Load the published artifacts, validate them and render all figures.

    Creates the output directory if needed, reads the summary JSON and the two
    ``.npy`` arrays named on the command line, checks the array shapes, then
    writes the three PNG figures.

    Raises:
        ValueError: If the mask or the scores array does not have the
            expected shape.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Decode explicitly as UTF-8: without an encoding, read_text() uses the
    # platform's locale encoding, which can mis-decode non-ASCII JSON.
    summary = json.loads(args.summary_path.read_text(encoding="utf-8"))
    mask = np.load(args.mask_path)
    scores = np.load(args.scores_path)
    # The mask is only shape-checked here as a sanity check on the artifact
    # set; none of the plots below consumes it.
    if mask.shape != (39, 256, 8):
        raise ValueError(f"unexpected mask shape: {mask.shape}")
    if scores.shape != (39, 256, 512):
        raise ValueError(f"unexpected scores shape: {scores.shape}")

    # The style must be configured before any figure is created.
    configure_style()
    plot_pruning_curve(summary, args.output_dir)
    plot_atomic_importance_by_layer(scores, args.output_dir)
    plot_gsm8k_curve(summary, args.output_dir)
|
|
|
|
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|