#!/usr/bin/env python3 """Plot Laguna Martini pruning insights from published artifacts.""" from __future__ import annotations import argparse import json from pathlib import Path import matplotlib.pyplot as plt import numpy as np import seaborn as sns def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--summary-path", type=Path, default=Path("results/summary.json")) parser.add_argument( "--mask-path", type=Path, default=Path("artifacts/group_keep_mask_25pct.npy"), ) parser.add_argument( "--scores-path", type=Path, default=Path("artifacts/atomic_scores.npy"), ) parser.add_argument("--output-dir", type=Path, default=Path("assets")) return parser.parse_args() def configure_style() -> None: sns.set_theme(style="whitegrid", context="talk") plt.rcParams.update( { "figure.facecolor": "#081522", "axes.facecolor": "#102235", "axes.edgecolor": "#7ed9e6", "axes.labelcolor": "#e6fbff", "axes.titlecolor": "#e6fbff", "grid.color": "#34536a", "text.color": "#e6fbff", "xtick.color": "#cceff4", "ytick.color": "#cceff4", } ) def plot_pruning_curve(summary: dict, output_dir: Path) -> None: rows = summary["native_group_loss_sweep"] baseline = summary["native_group_loss_sweep_baseline"]["perplexity"] x = np.array([row["groups_pruned_ratio"] * 100 for row in rows]) y = np.array([(row["perplexity"] / baseline - 1) * 100 for row in rows]) random_controls = [ summary["random_native_group_10pct_control"], summary["random_native_group_25pct_control"], ] random_x = np.array([0, *[row["groups_pruned_ratio"] * 100 for row in random_controls]]) random_y = np.array([0, *[(row["perplexity"] / baseline - 1) * 100 for row in random_controls]]) fig, ax = plt.subplots(figsize=(10, 5.6)) sns.lineplot( x=x, y=y, marker="o", linewidth=3, markersize=10, color="#78dce8", label="HEAPr-style importance ranking", ax=ax, ) sns.lineplot( x=random_x, y=random_y, marker="o", linewidth=2.5, markersize=9, linestyle="--", color="#f5a65b", label="Random control", ax=ax, ) selected = int(np.flatnonzero(x == 10)[0]) ax.scatter([x[selected]], [y[selected]], s=180, color="#b7d36b", zorder=4) ax.annotate( "Recommended operating point\n10% pruning, +1.66% perplexity", xy=(x[selected], y[selected]), xytext=(48, 34), textcoords="offset points", color="#e6fbff", fontsize=12, arrowprops={"arrowstyle": "->", "color": "#b7d36b"}, ) ax.set( title="Importance Ranking Outperforms Random Pruning", xlabel="Structured MoE parameters identified for removal (%)", ylabel="Perplexity increase vs. BF16 baseline (%)", ) ax.set_xlim(-1, 42) ax.set_ylim(-1, 44) ax.legend(frameon=False) sns.despine(ax=ax) fig.tight_layout() fig.savefig(output_dir / "pruning_quality_curve.png", dpi=180) plt.close(fig) def plot_atomic_importance_by_layer(scores: np.ndarray, output_dir: Path) -> None: positive_scores = scores[scores > 0] score_floor = max(float(positive_scores.min()), np.finfo(scores.dtype).tiny) log_scores = np.log10(np.clip(scores, score_floor, None)) layers = np.arange(scores.shape[0]) medians = np.median(log_scores, axis=(1, 2)) fig, ax = plt.subplots(figsize=(12, 5.8)) sns.boxplot( data=[layer.reshape(-1) for layer in log_scores], color="#78dce8", fliersize=0, linewidth=0.8, width=0.72, ax=ax, ) sns.lineplot( x=layers, y=medians, color="#b7d36b", linewidth=2.5, marker="o", markersize=4, label="Layer median", ax=ax, ) ax.set( title="Later Layers Have Lower-Importance Atomic Experts", xlabel="Sparse layer index", ylabel="log10(atomic importance score)", ) ax.set_ylim(np.quantile(log_scores, 0.005), np.quantile(log_scores, 0.995)) ax.tick_params(axis="x", labelsize=9) ax.legend(frameon=False) sns.despine(ax=ax) fig.tight_layout() fig.savefig(output_dir / "atomic_importance_by_layer.png", dpi=180) plt.close(fig) def plot_gsm8k_curve(summary: dict, output_dir: Path) -> None: gsm8k = summary["evaluation"]["gsm8k_cot_8shot_first_10pct"] x = np.array([0, 10, 20, 25]) strict = np.array( [ gsm8k["baseline_strict_exact_match"], gsm8k["native_group_10pct_strict_exact_match"], gsm8k["native_group_20pct_strict_exact_match"], gsm8k["native_group_25pct_strict_exact_match"], ] ) flexible = np.array( [ gsm8k["baseline_flexible_exact_match"], gsm8k["native_group_10pct_flexible_exact_match"], gsm8k["native_group_20pct_flexible_exact_match"], gsm8k["native_group_25pct_flexible_exact_match"], ] ) fig, ax = plt.subplots(figsize=(10, 5.6)) sns.lineplot( x=x, y=strict, marker="o", linewidth=3, markersize=10, color="#78dce8", label="Strict exact match", ax=ax, ) sns.lineplot( x=x, y=flexible, marker="o", linewidth=3, markersize=10, color="#b7d36b", label="Flexible exact match", ax=ax, ) ax.set( title="GSM8K-CoT Is Increasingly Sensitive to Pruning", xlabel="Structured MoE parameters identified for removal (%)", ylabel="Exact match on paired first-10% subset", ) ax.set_xticks(x) ax.set_ylim(0.62, 0.98) ax.legend(frameon=False) sns.despine(ax=ax) fig.tight_layout() fig.savefig(output_dir / "gsm8k_pruning_curve.png", dpi=180) plt.close(fig) def main() -> None: args = parse_args() args.output_dir.mkdir(parents=True, exist_ok=True) summary = json.loads(args.summary_path.read_text()) mask = np.load(args.mask_path) scores = np.load(args.scores_path) if mask.shape != (39, 256, 8): raise ValueError(f"unexpected mask shape: {mask.shape}") if scores.shape != (39, 256, 512): raise ValueError(f"unexpected scores shape: {scores.shape}") configure_style() plot_pruning_curve(summary, args.output_dir) plot_atomic_importance_by_layer(scores, args.output_dir) plot_gsm8k_curve(summary, args.output_dir) if __name__ == "__main__": main()