Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Plot head-head correlation results from saved .npz / .npy files. | |
| Usage: | |
| python scripts/plot_correlations.py --data corr_out --model gpt2 | |
| python scripts/plot_correlations.py --data corr_out --model gpt2 --metrics frob_cosine jensen_shannon | |
| python scripts/plot_correlations.py --data corr_out --model gpt2 --out figures/correlations | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import MaxNLocator | |
# ── Style ──────────────────────────────────────────────────────────────
# Shared figure styling, applied once at import time.
FONT_SIZE = 11   # default text size
TITLE_SIZE = 13  # axes titles and figure suptitles
DPI = 200        # figure resolution, also passed to savefig by the plotters

# Assign the rcParams one by one (same effect as a single update() call).
plt.rcParams["font.size"] = FONT_SIZE
plt.rcParams["axes.titlesize"] = TITLE_SIZE
plt.rcParams["figure.dpi"] = DPI
# ── Data loading ───────────────────────────────────────────────────────
def load_results(data_dir, model, revision="main", weight_type="W_QK"):
    """Load all saved correlation data for one model / weight-type run.

    Every file is looked up in ``data_dir`` under the common prefix
    ``{model}_{revision}_{weight_type}``.  The metadata JSON, the summary
    JSON and the ``_Q.npz`` archive are required; the per-metric ``.npy``
    files (eigenvalues, P_Q, block means) are optional and simply left out
    of the result when they do not exist.

    Args:
        data_dir: directory containing the saved outputs.
        model: model name used in the file prefix.
        revision: revision tag used in the file prefix.
        weight_type: weight component used in the file prefix.

    Returns:
        dict with keys ``"Q"``, ``"summary"``, ``"eigenvalues"``, ``"P_Q"``,
        ``"block_means"``, ``"metadata"`` and ``"keys"``.  ``"Q"`` maps each
        metric name (archive key with the leading ``Q_`` removed) to its
        matrix; ``"keys"`` is ``metadata["head_index"]`` as a list of tuples.

    Raises:
        FileNotFoundError: if one of the required files is missing.
    """
    prefix = f"{model}_{revision}_{weight_type}"

    def _path(suffix):
        return os.path.join(data_dir, f"{prefix}_{suffix}")

    with open(_path("metadata.json"), encoding="utf-8") as f:
        metadata = json.load(f)
    with open(_path("summary.json"), encoding="utf-8") as f:
        summary = json.load(f)

    # np.load on an .npz returns a lazy archive that keeps the file open;
    # read every array while the archive is open, then close it.
    with np.load(_path("Q.npz")) as Q_data:
        # Strip only the leading "Q_" so metric names that themselves
        # contain "Q_" are left intact.
        Q = {(k[2:] if k.startswith("Q_") else k): Q_data[k]
             for k in Q_data.files}

    # Optional per-metric arrays, keyed by the file-name suffix they use.
    optional = {"eigenvalues": {}, "P_Q": {}, "block_means": {}}
    for m in metadata["metrics"]:
        for kind, store in optional.items():
            path = _path(f"{m}_{kind}.npy")
            if os.path.exists(path):
                store[m] = np.load(path)

    keys = [tuple(k) for k in metadata["head_index"]]
    return {
        "Q": Q, "summary": summary, "eigenvalues": optional["eigenvalues"],
        "P_Q": optional["P_Q"], "block_means": optional["block_means"],
        "metadata": metadata, "keys": keys,
    }
# ── Plot functions ─────────────────────────────────────────────────────
| def _layer_boundaries(keys): | |
| layers = [k[0] for k in keys] | |
| bounds = [] | |
| for i in range(1, len(layers)): | |
| if layers[i] != layers[i - 1]: | |
| bounds.append(i) | |
| return bounds | |
| def _metric_display(name): | |
| return { | |
| "frob_cosine": "Frobenius cosine similarity", | |
| "symmetric_kl": "Symmetric KL divergence (KDE)", | |
| "jensen_shannon": "Jensen-Shannon divergence (KDE)", | |
| "hist_symmetric_kl": "Symmetric KL divergence (histogram)", | |
| "hist_jensen_shannon": "Jensen-Shannon divergence (histogram)", | |
| "two_point": "Two-point function $\\langle W_1 W_2 \\rangle$", | |
| "connected_corr": "Connected correlation $\\langle W_1 W_2 \\rangle - \\langle W_1 \\rangle \\langle W_2 \\rangle$", | |
| "pearson_corr": "Pearson correlation (normalized connected)", | |
| }.get(name, name) | |
| def _is_divergence(name): | |
| return name in ("symmetric_kl", "jensen_shannon", | |
| "hist_symmetric_kl", "hist_jensen_shannon") | |
| def _is_correlation_metric(name): | |
| """Metrics where a diverging (RdBu) colormap centered on 0 is appropriate.""" | |
| return name in ("frob_cosine", "connected_corr", "pearson_corr", "two_point") | |
| # Canonical 2Γ3 metric ordering: cosine + Pearson (similar shape), | |
| # symmetric KL + connected corr, JS + two-point. | |
| _METRIC_ORDER = [ | |
| "frob_cosine", "pearson_corr", | |
| "symmetric_kl", "connected_corr", | |
| "jensen_shannon", "two_point", | |
| ] | |
| _METRIC_ALT = { | |
| "symmetric_kl": "hist_symmetric_kl", | |
| "jensen_shannon": "hist_jensen_shannon", | |
| } | |
| def _order_metrics(available): | |
| """Return metrics in canonical 2Γ3 order, falling back to hist variants.""" | |
| ordered = [] | |
| for slot in _METRIC_ORDER: | |
| if slot in available: | |
| ordered.append(slot) | |
| elif slot in _METRIC_ALT and _METRIC_ALT[slot] in available: | |
| ordered.append(_METRIC_ALT[slot]) | |
| for m in available: | |
| if m not in ordered: | |
| ordered.append(m) | |
| return ordered | |
def plot_heatmap(Q, keys, metric_name, model_name, out_dir):
    """Draw a single Q_{hh'} heatmap and save it as a PNG.

    Args:
        Q: square matrix to display.
        keys: one entry per row/column of ``Q``; the first element of each
            entry is the layer index.
        metric_name: metric identifier; selects the colormap and the label.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(10, 9))
    bounds = _layer_boundaries(keys)

    if _is_divergence(metric_name):
        # Divergences are non-centered: sequential colormap, automatic range.
        im = ax.imshow(Q, cmap="viridis_r", aspect="equal")
    else:
        # Signed metrics: symmetric range around 0, clipped at the 98th
        # percentile of |Q| so a few extreme entries do not wash out the map.
        vmax = np.percentile(np.abs(Q), 98)
        im = ax.imshow(Q, cmap="RdBu_r", aspect="equal", vmin=-vmax, vmax=vmax)

    for b in bounds:
        ax.axhline(b - 0.5, color="white", linewidth=0.5, alpha=0.8)
        ax.axvline(b - 0.5, color="white", linewidth=0.5, alpha=0.8)

    # Layer labels at the midpoint of every contiguous run of same-layer
    # heads.  Deriving the runs from the actual boundaries keeps the ticks
    # correct when layers have different head counts, and avoids dividing by
    # the number of layers (which is zero for empty ``keys``).
    tick_pos, tick_labels = [], []
    for start, end in zip([0] + bounds, bounds + [len(keys)]):
        if end > start:
            tick_pos.append(start + (end - start) // 2)
            tick_labels.append(str(keys[start][0]))
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_labels, fontsize=9)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(tick_labels, fontsize=9)

    ax.set_xlabel("Layer")
    ax.set_ylabel("Layer")
    ax.set_title(f"{model_name} — $Q_{{hh'}}$ ({_metric_display(metric_name)})")
    fig.colorbar(im, ax=ax, shrink=0.8, label=_metric_display(metric_name))
    fig.tight_layout()

    fpath = os.path.join(out_dir, f"{model_name}_Q_heatmap_{metric_name}.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
def plot_P_Q(P_Q_dict, summary, model_name, out_dir):
    """Plot the overlap distributions P(Q) on the canonical 2×3 grid.

    Args:
        P_Q_dict: mapping metric name -> array of Q values to histogram.
        summary: mapping metric name -> dict that may provide
            ``"mean_offdiag"`` and ``"std_offdiag"``.  A missing metric or
            missing entry is replaced by the mean / standard deviation of
            the histogrammed values.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure.  At most six metrics are drawn; panels
        without a metric are hidden.
    """
    shown = _order_metrics(P_Q_dict.keys())[:6]
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = list(axes.flat)

    for ax, m in zip(axes, shown):
        vals = P_Q_dict[m]
        ax.hist(vals, bins=60, density=True, alpha=0.7, color="#636EFA",
                edgecolor="white", linewidth=0.3)
        # The summary may lack this metric (or a field) when it was filtered
        # or produced by an older run; fall back to the data being plotted.
        stats = summary.get(m, {})
        mu = stats.get("mean_offdiag", np.mean(vals))
        sigma = stats.get("std_offdiag", np.std(vals))
        ax.axvline(mu, color="#EF553B", linestyle="--", linewidth=1.2,
                   label=f"$\\mu$ = {mu:.3f}, $\\sigma$ = {sigma:.3f}")
        ax.set_title(_metric_display(m), fontsize=10)
        ax.set_xlabel("$Q$ value")
        ax.set_ylabel("density")
        ax.set_yscale("log")
        ax.legend(fontsize=9)

    for ax in axes[len(shown):]:
        ax.set_visible(False)

    fig.suptitle(f"{model_name} — Overlap distributions $P(Q)$", fontsize=TITLE_SIZE)
    fig.tight_layout()
    fpath = os.path.join(out_dir, f"{model_name}_P_Q.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
def plot_eigenvalues(eig_dict, model_name, out_dir):
    """Plot the eigenvalue spectra of Q on the canonical 2×3 grid.

    Each panel shows the eigenvalue magnitudes of one metric, sorted in
    descending order, on a logarithmic y axis.

    Args:
        eig_dict: mapping metric name -> array of eigenvalues.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure.  At most six metrics are drawn; panels
        without a metric are hidden.
    """
    shown = _order_metrics(eig_dict.keys())[:6]
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = list(axes.flat)

    for ax, m in zip(axes, shown):
        abs_eig = np.sort(np.abs(eig_dict[m]))[::-1]
        ax.plot(abs_eig, "o-", markersize=3, color="#EF553B")
        ax.set_yscale("log")
        ax.set_ylabel("$|\\lambda|$")
        ax.set_title(_metric_display(m), fontsize=10)
        ax.set_xlabel("Index")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    for ax in axes[len(shown):]:
        ax.set_visible(False)

    # The title separator is restored as an em dash (it was garbled text).
    fig.suptitle(f"{model_name} — Eigenvalues of $Q_{{hh'}}$", fontsize=TITLE_SIZE)
    fig.tight_layout()
    fpath = os.path.join(out_dir, f"{model_name}_Q_eigenvalues.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
def plot_block_means(block_dict, metadata, model_name, out_dir):
    """Save one layer × layer block-mean heatmap per metric.

    Args:
        block_dict: mapping metric name -> block-mean matrix.
        metadata: run metadata; ``metadata["n_layers"]`` sets the tick count.
        model_name: tag used in the titles and in the output filenames.
        out_dir: directory the figures are written to.

    Returns:
        List of the saved figure paths, in the iteration order of
        ``block_dict`` (the sibling plot functions return their path too).
    """
    n_layers = metadata["n_layers"]
    layer_labels = list(range(n_layers))
    # With many layers the tick labels are shrunk and the x labels rotated.
    crowded = n_layers > 20
    tick_fs = 7 if crowded else 9
    tick_rot = 90 if crowded else 0

    saved = []
    for m, block in block_dict.items():
        fig, ax = plt.subplots(figsize=(7, 6))
        if _is_divergence(m):
            im = ax.imshow(block, cmap="viridis_r", aspect="equal")
        else:
            # Symmetric colour range so 0 sits at the colormap midpoint.
            vmax = np.max(np.abs(block))
            im = ax.imshow(block, cmap="RdBu_r", aspect="equal",
                           vmin=-vmax, vmax=vmax)
        ax.set_xticks(range(n_layers))
        ax.set_yticks(range(n_layers))
        ax.set_xticklabels(layer_labels, fontsize=tick_fs, rotation=tick_rot)
        ax.set_yticklabels(layer_labels, fontsize=tick_fs)
        ax.set_xlabel("Layer")
        ax.set_ylabel("Layer")
        ax.set_title(f"{model_name} — Layer-block means\n{_metric_display(m)}")
        fig.colorbar(im, ax=ax, shrink=0.8)
        fig.tight_layout()
        fpath = os.path.join(out_dir, f"{model_name}_block_means_{m}.png")
        fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
        plt.close(fig)
        print(f" {fpath}")
        saved.append(fpath)
    return saved
def plot_correlation_vs_layer_distance(P_Q_dict, keys, Q_dict, model_name, out_dir):
    """Plot mean |Q| as a function of layer distance |l - l'| for each metric.

    Only the strict upper triangle of each matrix is used, so every head
    pair counts once and the diagonal is excluded.  Panels follow the
    canonical order from ``_order_metrics`` on a fixed 2×3 grid; unused
    panels are hidden and at most six metrics are drawn.

    Args:
        P_Q_dict: unused; kept so existing callers keep working.
        keys: one entry per row/column of the matrices; the first element of
            each entry is the layer index.
        Q_dict: mapping metric name -> square matrix.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure.
    """
    layers = np.array([k[0] for k in keys])
    n = len(keys)

    # Pair indices and their layer distances do not depend on the metric,
    # so they are computed once instead of once per panel.
    triu_i, triu_j = np.triu_indices(n, k=1)
    dists = np.abs(layers[triu_i] - layers[triu_j])
    unique_d = np.unique(dists)
    groups = [dists == d for d in unique_d]

    shown = _order_metrics(Q_dict.keys())[:6]
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = list(axes.flat)

    for ax, m in zip(axes, shown):
        abs_vals = np.abs(Q_dict[m][triu_i, triu_j])
        means = [np.mean(abs_vals[g]) for g in groups]
        # The error bar is the spread of the plotted quantity (|Q|), not of
        # the signed values, so bar and marker describe the same sample.
        stds = [np.std(abs_vals[g]) for g in groups]
        ax.errorbar(unique_d, means, yerr=stds, fmt="o-", markersize=4,
                    capsize=3, color="#636EFA")
        ax.set_xlabel("Layer distance $|\\ell - \\ell'|$")
        ax.set_ylabel("Mean $|Q|$" if not _is_divergence(m) else "Mean $Q$")
        ax.set_title(_metric_display(m), fontsize=10)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    for ax in axes[len(shown):]:
        ax.set_visible(False)

    fig.suptitle(f"{model_name} — Correlation vs. layer distance", fontsize=TITLE_SIZE)
    fig.tight_layout()
    fpath = os.path.join(out_dir, f"{model_name}_corr_vs_layer_distance.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
# ── Marchenko-Pastur overlay ───────────────────────────────────────────
| def _mp_density(lam, gamma): | |
| """Marchenko-Pastur density for aspect ratio gamma = N/p.""" | |
| lam_m = (1 - np.sqrt(gamma)) ** 2 | |
| lam_p = (1 + np.sqrt(gamma)) ** 2 | |
| mask = (lam >= lam_m) & (lam <= lam_p) | |
| density = np.zeros_like(lam) | |
| density[mask] = (np.sqrt((lam_p - lam[mask]) * (lam[mask] - lam_m)) | |
| / (2 * np.pi * gamma * lam[mask])) | |
| return density | |
def compute_Q_eigen_stats(eigvals, gamma):
    """Summarise a Q spectrum and the matching Marchenko-Pastur predictions.

    The real parts of ``eigvals`` are sorted in descending order and reduced
    to a condition number, an NPR value and a stable rank.  The same three
    quantities are also evaluated from closed-form expressions in ``gamma``.

    Returns:
        dict with the measured values, their ``mp_``-prefixed predictions,
        ``gamma``, the eigenvalue count ``N`` and the extreme eigenvalues.
    """
    spectrum = np.sort(np.real(eigvals))[::-1]
    n_eig = len(spectrum)
    top, bottom = spectrum[0], spectrum[-1]
    sum_of_squares = np.sum(spectrum ** 2)

    # Measured quantities; denominators are floored at 1e-12.
    measured = {
        "condition_number": top / max(bottom, 1e-12),
        "npr": (np.sum(spectrum) ** 2) / (n_eig * sum_of_squares),
        "stable_rank": sum_of_squares / max(top ** 2, 1e-12),
    }

    # Marchenko-Pastur predictions.
    root_gamma = np.sqrt(gamma)
    edge_hi = (1 + root_gamma) ** 2
    edge_lo = (1 - root_gamma) ** 2
    predicted = {
        "mp_condition_number": edge_hi / max(edge_lo, 1e-12),
        "mp_npr": 1 / (1 + gamma),
        # E[tr(Q²)] / E[λ_max]²
        "mp_stable_rank": n_eig * (1 + gamma) / (1 + root_gamma) ** 4,
    }

    return {
        **measured,
        **predicted,
        "gamma": gamma,
        "N": n_eig,
        "lam_max": top,
        "lam_min": bottom,
    }
def plot_mp_overlay(Q_frob, metadata, model_name, out_dir):
    """Compare the spectrum of Q^(Frob) with the Marchenko-Pastur prediction.

    The aspect ratio is ``gamma = N / head_dim**2`` with
    ``N = n_layers * n_heads`` taken from ``metadata``.  The left panel shows
    the ordered eigenvalues against the MP bulk band, an outlier annotation
    and a small table of measured vs. predicted statistics.  The right panel
    shows the eigenvalue histogram with the MP density curve; its x range is
    clamped so the bulk stays readable.

    Returns:
        ``(fpath, stats)``: the saved figure path and the dict produced by
        ``compute_Q_eigen_stats``.
    """
    n_total = metadata["n_layers"] * metadata["n_heads"]
    n_features = metadata["head_dim"] ** 2
    gamma = n_total / n_features
    root_gamma = np.sqrt(gamma)
    bulk_lo = (1 - root_gamma) ** 2
    bulk_hi = (1 + root_gamma) ** 2

    # eigvalsh returns ascending order; flip to descending.
    spectrum = np.linalg.eigvalsh(Q_frob)[::-1]
    n_outliers = int(np.sum(spectrum > bulk_hi))

    fig, (ax_rank, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    accent = "#636EFA"
    data_color = "#EF553B"

    # ── Left: ordered eigenvalues with MP band ──
    ranks = np.arange(len(spectrum))
    ax_rank.semilogy(ranks, np.maximum(spectrum, 1e-8), "o-", markersize=3,
                     color=data_color, label=f"{model_name} (trained)", zorder=3)
    ax_rank.axhspan(bulk_lo, bulk_hi, alpha=0.15, color=accent,
                    label=f"MP bulk [{bulk_lo:.2f}, {bulk_hi:.2f}]", zorder=1)
    for edge in (bulk_hi, bulk_lo):
        ax_rank.axhline(edge, color=accent, linestyle="--", linewidth=1, alpha=0.7)

    if n_outliers > 0:
        # The arrow targets the smallest outlier, i.e. position
        # n_outliers - 1 of the descending spectrum.
        target = n_outliers - 1
        plural = "s" if n_outliers > 1 else ""
        ax_rank.annotate(
            f"{n_outliers} outlier{plural} "
            f"above MP edge ($\\lambda_{{max}}$={spectrum[0]:.1f})",
            xy=(target, spectrum[target]),
            xytext=(min(target + 8, n_total * 0.4), spectrum[target] * 1.5),
            fontsize=9, color=accent,
            arrowprops={"arrowstyle": "->", "color": accent, "lw": 1},
            annotation_clip=True,
        )

    ax_rank.set_xlim(-0.5, n_total + 0.5)
    ax_rank.set_xlabel("Index")
    ax_rank.set_ylabel("$\\lambda$")
    ax_rank.set_title(f"Eigenvalue spectrum of $Q_{{hh'}}^{{\\mathrm{{(Frob)}}}}$\n"
                      f"({model_name}, {n_total} heads, $\\gamma$ = {gamma:.4f})")
    ax_rank.legend(fontsize=9, loc="upper right")

    # ── Right: histogram with MP density overlay ──
    # Focus on the MP bulk plus a modest outlier range ...
    x_limit = min(bulk_hi * 4, spectrum.max() * 1.1)
    # ... but never hide more than the top 10 % of the eigenvalues.
    ascending = np.sort(spectrum)
    if len(ascending):
        p90 = ascending[int(0.9 * len(ascending))]
    else:
        p90 = 1.0
    x_limit = max(x_limit, p90 * 1.5)

    edges = np.linspace(0, x_limit, 60)
    visible = spectrum[spectrum <= x_limit * 1.1]
    ax_hist.hist(visible, bins=edges, density=True,
                 alpha=0.6, color=data_color, edgecolor="white", linewidth=0.3,
                 label=f"{model_name} eigenvalues")

    grid = np.linspace(0.01, bulk_hi * 1.5, 500)
    ax_hist.plot(grid, _mp_density(grid, gamma), "-", color=accent, linewidth=2.5,
                 label=f"MP density ($\\gamma$ = {gamma:.4f})")

    # Outliers often fall outside the histogram range, so list them as text.
    outliers = spectrum[spectrum > bulk_hi]
    if len(outliers) > 0:
        # Compact label: every value when there are few, else the first four.
        if len(outliers) <= 5:
            shown_vals = [f"{v:.1f}" for v in outliers]
        else:
            shown_vals = [f"{v:.1f}" for v in outliers[:4]] + ["..."]
        plural = "s" if len(outliers) > 1 else ""
        ax_hist.text(
            0.97, 0.95,
            f"{len(outliers)} outlier{plural}"
            f" > $\\lambda_+$\n$\\lambda$ = {', '.join(shown_vals)}",
            transform=ax_hist.transAxes, fontsize=8, color=accent,
            ha="right", va="top",
            bbox={"boxstyle": "round,pad=0.3", "fc": "#1a1a2a", "ec": accent,
                  "alpha": 0.8},
        )

    ax_hist.axvline(bulk_hi, color=accent, linestyle="--", linewidth=1, alpha=0.7)
    ax_hist.set_xlabel("Eigenvalue $\\lambda$")
    ax_hist.set_ylabel("Density")
    ax_hist.set_title("Eigenvalue distribution vs. MP prediction")
    ax_hist.legend(fontsize=9)
    ax_hist.set_xlim(0, x_limit)

    # ── Stats inset on left panel ──
    stats = compute_Q_eigen_stats(spectrum, gamma)
    table_rows = [
        f"{'':>12s} {'Meas':>8s} {'MP':>8s}",
        f"{'C':>12s} {stats['condition_number']:>8.1f} {stats['mp_condition_number']:>8.2f}",
        f"{'NPR':>12s} {stats['npr']:>8.3f} {stats['mp_npr']:>8.3f}",
        f"{'stable rank':>12s} {stats['stable_rank']:>8.1f} {stats['mp_stable_rank']:>8.1f}",
    ]
    ax_rank.text(
        0.97, 0.45, "\n".join(table_rows),
        transform=ax_rank.transAxes, fontsize=7.5, fontfamily="monospace",
        ha="right", va="top",
        bbox={"boxstyle": "round,pad=0.4", "fc": "#1a1a2a", "ec": "#888888",
              "alpha": 0.85},
        color="#e0e0e0",
    )

    fig.tight_layout()
    fpath = os.path.join(out_dir, f"{model_name}_MP_overlay.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath, stats
# ── Dominant eigenvector visualization ─────────────────────────────────
def plot_dominant_eigenvector(Q_frob, metadata, model_name, out_dir, n_modes=3):
    """Show the dominant eigenvectors of Q^(Frob) as layer × head heatmaps.

    Top row: one heatmap per mode (eigenvector reshaped to layer × head)
    followed by a grouped bar chart of the per-layer loading.
    Bottom row: per-head loading bar chart under each mode; the last cell is
    left blank.

    The number of modes shown is ``min(n_modes, n_outliers, 3)`` but never
    less than one, where ``n_outliers`` counts eigenvalues above the
    Marchenko-Pastur edge ``(1 + sqrt(gamma))**2`` with
    ``gamma = n_layers * n_heads / head_dim**2``.

    Args:
        Q_frob: symmetric matrix whose eigenvectors have length
            ``n_layers * n_heads`` (required by the reshape below).
        metadata: provides ``n_layers``, ``n_heads`` and ``head_dim``.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.
        n_modes: requested number of modes (capped as described above).

    Returns:
        Path of the saved figure.
    """
    n_layers = metadata["n_layers"]
    n_heads = metadata["n_heads"]
    N = n_layers * n_heads

    eigvals, eigvecs = np.linalg.eigh(Q_frob)
    idx = np.argsort(eigvals)[::-1]  # descending order
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    d_head = metadata["head_dim"]
    gamma = N / (d_head ** 2)
    lam_plus = (1 + np.sqrt(gamma)) ** 2
    n_outliers = int(np.sum(eigvals > lam_plus))
    n_show = max(1, min(n_modes, n_outliers, 3))

    # n_cols is at least 2 (n_show >= 1 plus the layer-loading column), so
    # plt.subplots always returns a 2-D array of axes here.
    n_cols = n_show + 1
    fig, axes = plt.subplots(
        2, n_cols,
        figsize=(5 * n_cols, 9),
        gridspec_kw={
            "width_ratios": [1] * n_show + [0.6],
            "height_ratios": [1.2, 1],
        },
    )
    mode_colors = ["#636EFA", "#EF553B", "#00CC96"]

    # Each mode as a (layer, head) grid, plus its squared entries which feed
    # both loading charts.
    grids = [eigvecs[:, k].reshape(n_layers, n_heads) for k in range(n_show)]
    sq_grids = [g ** 2 for g in grids]

    # ── Top row: mode heatmaps + layer loading ──
    for k, v_grid in enumerate(grids):
        ax = axes[0, k]
        vmax = np.max(np.abs(v_grid))
        im = ax.imshow(v_grid, cmap="RdBu_r", vmin=-vmax, vmax=vmax, aspect="auto")
        ax.set_xlabel("Head")
        ax.set_ylabel("Layer")
        ax.set_title(f"Mode {k + 1}: $\\lambda_{{{k + 1}}}$ = {eigvals[k]:.1f}")
        fig.colorbar(im, ax=ax, shrink=0.7)

    # Top-right: layer loading (sum of squared entries over the heads).
    ax = axes[0, -1]
    layers_arr = np.arange(n_layers)
    width = 0.8 / n_show
    for k, sq in enumerate(sq_grids):
        ax.barh(layers_arr + k * width, sq.sum(axis=1), height=width,
                color=mode_colors[k % len(mode_colors)], alpha=0.7,
                label=f"Mode {k + 1}")
    ax.set_ylabel("Layer")
    ax.set_xlabel("$\\sum_h v^2_{(\\ell,h)}$")
    ax.set_title("Layer loading")
    ax.invert_yaxis()
    ax.legend(fontsize=8)

    # ── Bottom row: head loading per mode (sum over the layers) ──
    heads_arr = np.arange(n_heads)
    for k, sq in enumerate(sq_grids):
        ax = axes[1, k]
        ax.bar(heads_arr, sq.sum(axis=0),
               color=mode_colors[k % len(mode_colors)], alpha=0.7)
        ax.set_xlabel("Head")
        ax.set_ylabel("$\\sum_\\ell v^2_{(\\ell,h)}$")
        ax.set_title(f"Head loading — Mode {k + 1}")

    # Bottom-right: blank
    axes[1, -1].axis("off")

    fig.suptitle(f"{model_name} — Dominant eigenvectors of "
                 f"$Q_{{hh'}}^{{\\mathrm{{(Frob)}}}}$\n"
                 f"({n_outliers} outlier{'s' if n_outliers != 1 else ''} "
                 f"above MP edge at $\\lambda$ = {lam_plus:.2f})",
                 fontsize=TITLE_SIZE)
    fig.tight_layout()
    fpath = os.path.join(out_dir, f"{model_name}_dominant_eigenvectors.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
# ── Cross-correlation heatmap ──────────────────────────────────────────
def plot_cross_heatmap(Q, keys, metric_name, label, model_name, out_dir):
    """Draw a heatmap of a cross-correlation matrix (not necessarily symmetric).

    The diagonal shows intra-head cross-circuit coupling.

    Args:
        Q: matrix to display.
        keys: one entry per row/column of ``Q``; the first element of each
            entry is the layer index.
        metric_name: metric identifier; selects the colormap and the label.
        label: cross-correlation label of the form ``"A_vs_B"``; ``A`` names
            the y axis and ``B`` the x axis.  Without ``"_vs_"`` the whole
            label names the y axis and the x axis falls back to ``"B"``.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure (``/`` in ``label`` is replaced by ``_``).
    """
    fig, ax = plt.subplots(figsize=(10, 9))
    bounds = _layer_boundaries(keys)

    if _is_divergence(metric_name):
        im = ax.imshow(Q, cmap="viridis_r", aspect="equal")
    else:
        # Symmetric range around 0, clipped at the 98th percentile of |Q|.
        vmax = np.percentile(np.abs(Q), 98)
        im = ax.imshow(Q, cmap="RdBu_r", aspect="equal", vmin=-vmax, vmax=vmax)

    for b in bounds:
        ax.axhline(b - 0.5, color="white", linewidth=0.5, alpha=0.8)
        ax.axvline(b - 0.5, color="white", linewidth=0.5, alpha=0.8)

    # One tick at the midpoint of every contiguous run of same-layer heads.
    # Using the real run boundaries keeps ticks correct for layers with
    # different head counts and avoids a division by zero for empty keys.
    tick_pos, tick_labels = [], []
    for start, end in zip([0] + bounds, bounds + [len(keys)]):
        if end > start:
            tick_pos.append(start + (end - start) // 2)
            tick_labels.append(str(keys[start][0]))
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_labels, fontsize=9)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(tick_labels, fontsize=9)

    # str.split always yields at least one part, so only the x-axis name
    # needs a fallback.
    parts = label.split("_vs_")
    ax.set_xlabel(f"Layer ({parts[1] if len(parts) > 1 else 'B'})")
    ax.set_ylabel(f"Layer ({parts[0]})")
    ax.set_title(f"{model_name} — Cross-correlation ({label})\n"
                 f"{_metric_display(metric_name)}")
    fig.colorbar(im, ax=ax, shrink=0.8, label=_metric_display(metric_name))
    fig.tight_layout()

    safe_label = label.replace("/", "_")
    fpath = os.path.join(out_dir,
                         f"{model_name}_cross_{safe_label}_{metric_name}.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
def plot_cross_diagonal(cross_Q_dict, keys, label, model_name, out_dir):
    """Plot the diagonal of cross-correlation matrices (intra-head cross-circuit).

    For every metric one panel shows the diagonal entries scattered against
    the layer index, together with the per-layer mean and standard deviation.

    Args:
        cross_Q_dict: mapping metric name -> cross-correlation matrix.
        keys: one entry per diagonal element; the first element of each
            entry is the layer index.
        label: cross-correlation label, used in panel titles and filename.
        model_name: tag used in the title and in the output filename.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure, or ``None`` when ``cross_Q_dict`` is empty
        (a figure with zero panels cannot be created).
    """
    metrics = list(cross_Q_dict)
    if not metrics:
        return None

    layers = np.array([k[0] for k in keys])
    # Layer grouping is the same for every metric; build the masks once.
    unique_layers = np.unique(layers)
    layer_masks = [layers == l for l in unique_layers]

    # squeeze=False always yields a 2-D axes array, so the single-metric
    # case needs no special handling.
    fig, axes = plt.subplots(1, len(metrics),
                             figsize=(6 * len(metrics), 4.5), squeeze=False)
    for ax, m in zip(axes[0], metrics):
        diag = np.diag(cross_Q_dict[m])
        layer_means = [np.mean(diag[mask]) for mask in layer_masks]
        layer_stds = [np.std(diag[mask]) for mask in layer_masks]
        # scatter all heads
        ax.scatter(layers, diag, alpha=0.3, s=15, color="#636EFA", zorder=2)
        # layer means
        ax.errorbar(unique_layers, layer_means, yerr=layer_stds,
                    fmt="o-", markersize=6, capsize=3, color="#EF553B",
                    linewidth=2, zorder=3, label="layer mean")
        ax.set_xlabel("Layer")
        ax.set_ylabel(_metric_display(m))
        ax.set_title(f"Intra-head {label}", fontsize=10)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend(fontsize=8)

    fig.suptitle(f"{model_name} — Same-head cross-circuit coupling",
                 fontsize=TITLE_SIZE)
    fig.tight_layout()
    safe_label = label.replace("/", "_")
    fpath = os.path.join(out_dir,
                         f"{model_name}_cross_diagonal_{safe_label}.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
# ── Multi-model comparison plots ───────────────────────────────────────
# Line colours for multi-model overlays; indexed modulo the list length,
# so more models than colours simply wrap around.
MODEL_COLORS = [
    "#636EFA",
    "#EF553B",
    "#00CC96",
    "#AB63FA",
    "#FFA15A",
    "#19D3F3",
    "#FF6692",
    "#B6E880",
    "#FF97FF",
    "#FECB52",
    "#E45756",
]
def plot_eigenvalue_comparison(data_dir, models, revision="main",
                               weight_type="W_QK", metric="frob_cosine",
                               out_dir="."):
    """Overlay the eigenvalue spectra of Q for multiple models on one plot.

    Each model's saved eigenvalues for ``metric`` are drawn as sorted
    magnitudes on a logarithmic y axis.  Models whose files are missing, or
    that have no eigenvalues for ``metric``, are skipped with a printed
    notice.

    Args:
        data_dir: directory containing the saved outputs.
        models: iterable of model names.
        revision: revision tag used in the file prefix.
        weight_type: weight component used in the file prefix.
        metric: metric whose eigenvalues are compared.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure, or ``None`` when no model could be plotted
        (mirrors ``plot_eigen_stats_comparison``).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    n_plotted = 0
    for i, model in enumerate(models):
        try:
            r = load_results(data_dir, model, revision, weight_type)
        except FileNotFoundError as err:
            # Best effort: keep going, but say which model is missing.
            print(f" (skipping {model}: {err})")
            continue
        if metric not in r["eigenvalues"]:
            print(f" (skipping {model}: no saved eigenvalues for {metric})")
            continue
        eigvals = r["eigenvalues"][metric]
        abs_eig = np.sort(np.abs(eigvals))[::-1]
        N = r["metadata"]["n_layers"] * r["metadata"]["n_heads"]
        # Colour follows the model's position in ``models`` so it stays
        # stable when other models are skipped.
        color = MODEL_COLORS[i % len(MODEL_COLORS)]
        ax.plot(abs_eig, "o-", markersize=3, color=color,
                label=f"{model} ({N} heads)", alpha=0.8)
        n_plotted += 1

    if n_plotted == 0:
        # Nothing to show: do not write an empty figure.
        plt.close(fig)
        return None

    ax.set_yscale("log")
    ax.set_xlabel("Index")
    ax.set_ylabel("$|\\lambda|$")
    ax.set_title(f"Eigenvalue spectra of $Q_{{hh'}}$ ({_metric_display(metric)})\n"
                 f"Component: {weight_type}")
    ax.legend(fontsize=8, loc="upper right")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()
    fpath = os.path.join(out_dir,
                         f"all_models_{weight_type}_eigenvalues_{metric}.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
def plot_eigen_stats_comparison(data_dir, models, revision="main",
                                weight_type="W_QK", out_dir="."):
    """Bar chart comparing condition number, NPR and stable rank across models.

    For every model the eigenvalues of its ``frob_cosine`` matrix are
    computed and passed to ``compute_Q_eigen_stats`` together with
    ``gamma = n_layers * n_heads / head_dim**2``.  Measured values and
    Marchenko-Pastur predictions are drawn side by side.  Models whose files
    are missing, or that have no ``frob_cosine`` matrix, are skipped with a
    printed notice.

    Args:
        data_dir: directory containing the saved outputs.
        models: iterable of model names.
        revision: revision tag used in the file prefix.
        weight_type: weight component used in the file prefix.
        out_dir: directory the figure is written to.

    Returns:
        Path of the saved figure, or ``None`` when no model could be used.
    """
    names = []
    # (measured key, predicted key) for each of the three panels.
    stat_keys = [
        ("condition_number", "mp_condition_number"),
        ("npr", "mp_npr"),
        ("stable_rank", "mp_stable_rank"),
    ]
    measured = {key: [] for key, _ in stat_keys}
    predicted = {key: [] for _, key in stat_keys}

    for model in models:
        try:
            r = load_results(data_dir, model, revision, weight_type)
        except FileNotFoundError as err:
            # Best effort: keep going, but say which model is missing.
            print(f" (skipping {model}: {err})")
            continue
        if "frob_cosine" not in r["Q"]:
            print(f" (skipping {model}: no frob_cosine matrix)")
            continue
        N = r["metadata"]["n_layers"] * r["metadata"]["n_heads"]
        d_head = r["metadata"]["head_dim"]
        gamma = N / (d_head ** 2)
        eigvals = np.linalg.eigvalsh(r["Q"]["frob_cosine"])[::-1]
        s = compute_Q_eigen_stats(eigvals, gamma)
        names.append(model)
        for meas_key, pred_key in stat_keys:
            measured[meas_key].append(s[meas_key])
            predicted[pred_key].append(s[pred_key])

    if not names:
        return None

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    x = np.arange(len(names))
    w = 0.35  # bar width; measured and predicted bars sit side by side
    panels = [
        ("Condition number $C$", "$\\lambda_{max}/\\lambda_{min}$"),
        ("NPR", "$(\\Sigma\\lambda)^2 / (N \\Sigma\\lambda^2)$"),
        ("Stable rank", "$\\Sigma\\lambda^2 / \\lambda_{max}^2$"),
    ]
    for ax, (meas_key, pred_key), (title, ylabel) in zip(axes, stat_keys, panels):
        ax.bar(x - w / 2, measured[meas_key], w, label="Measured",
               color="#EF553B", alpha=0.8)
        ax.bar(x + w / 2, predicted[pred_key], w, label="MP prediction",
               color="#636EFA", alpha=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha="right", fontsize=8)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.legend(fontsize=8)

    # Log scale for condition number (huge dynamic range)
    axes[0].set_yscale("log")
    fig.suptitle(f"Q eigenvalue statistics vs. Marchenko-Pastur ({weight_type})",
                 fontsize=TITLE_SIZE)
    fig.tight_layout()
    fpath = os.path.join(out_dir,
                         f"all_models_{weight_type}_eigen_stats.png")
    fig.savefig(fpath, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" {fpath}")
    return fpath
# ── Auto-discovery helpers ─────────────────────────────────────────────
def discover_weight_types(data_dir, model, revision="main"):
    """Find all weight types (W_QK, W_OV, etc.) with saved data for a model.

    Scans ``data_dir`` for files named
    ``{model}_{revision}_<middle>_metadata.json``.  A ``<middle>`` containing
    ``"_vs_"`` is treated as a cross-correlation label, anything else as a
    weight type.

    Args:
        data_dir: directory containing the saved outputs.
        model: model name used in the file prefix.
        revision: revision tag used in the file prefix.

    Returns:
        ``(weight_types, cross_labels)``: two lists, each ordered by the
        sorted file paths.
    """
    import glob

    prefix = f"{model}_{revision}_"
    suffix = "_metadata.json"
    # Only the "*" is meant as a wildcard.  Escape everything else so names
    # containing glob metacharacters ("[", "]", "*", "?") match literally.
    pattern = os.path.join(glob.escape(data_dir),
                           glob.escape(prefix) + "*" + suffix)

    weight_types = []
    cross_labels = []
    for path in sorted(glob.glob(pattern)):
        fname = os.path.basename(path)
        mid = fname[len(prefix):-len(suffix)]
        # Cross-correlations are recognised by the "_vs_" separator.
        if "_vs_" in mid:
            cross_labels.append(mid)
        else:
            weight_types.append(mid)
    return weight_types, cross_labels
def load_cross_results(data_dir, model, label, revision="main"):
    """Load saved cross-correlation data.

    Reads ``{model}_{revision}_{label}_metadata.json`` and the matching
    ``_Q.npz`` archive from ``data_dir``.

    Args:
        data_dir: directory containing the saved outputs.
        model: model name used in the file prefix.
        label: cross-correlation label used in the file prefix.
        revision: revision tag used in the file prefix.

    Returns:
        dict with ``"Q"`` (metric name -> matrix, archive key with the
        leading ``Q_`` removed), ``"metadata"`` and ``"keys"``
        (``metadata["head_index"]`` as a list of tuples).

    Raises:
        FileNotFoundError: if either file is missing.
    """
    prefix = f"{model}_{revision}_{label}"
    with open(os.path.join(data_dir, f"{prefix}_metadata.json"),
              encoding="utf-8") as f:
        metadata = json.load(f)

    # np.load on an .npz returns a lazy archive that keeps the file open;
    # read every array while the archive is open, then close it.
    with np.load(os.path.join(data_dir, f"{prefix}_Q.npz")) as Q_data:
        # Strip only the leading "Q_" so metric names that themselves
        # contain "Q_" are left intact.
        Q = {(k[2:] if k.startswith("Q_") else k): Q_data[k]
             for k in Q_data.files}

    keys = [tuple(k) for k in metadata["head_index"]]
    return {"Q": Q, "metadata": metadata, "keys": keys}
# ── Main ───────────────────────────────────────────────────────────────
def plot_weight_type(data_dir, model, revision, weight_type, out_dir, metrics_filter=None):
    """Produce every standard figure for a single weight type.

    Loads the saved results, restricts them to ``metrics_filter`` (or to all
    metrics listed in the metadata when the filter is empty) and calls the
    individual plot functions for whatever data is available.  When a
    ``frob_cosine`` matrix is present, the Marchenko-Pastur overlay, a JSON
    file with its eigenvalue statistics and the dominant-eigenvector figure
    are written as well.
    """
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Weight type: {weight_type}")
    print(rule)

    results = load_results(data_dir, model, revision, weight_type)
    wanted = metrics_filter or results["metadata"]["metrics"]

    def pick(store):
        # Keep the requested metrics that this store actually contains.
        return {m: store[m] for m in wanted if m in store}

    Q = pick(results["Q"])
    P_Q = pick(results["P_Q"])
    eig = pick(results["eigenvalues"])
    blk = pick(results["block_means"])
    summary = pick(results["summary"])

    # The component is always part of the file names, for consistency.
    name_tag = f"{model}_{weight_type}"
    head_keys = results["keys"]
    meta = results["metadata"]

    print(f"Generating figures for: {list(Q.keys())}")
    print("Heatmaps:")
    for metric, matrix in Q.items():
        plot_heatmap(matrix, head_keys, metric, name_tag, out_dir)

    if P_Q:
        print("P(Q) distributions:")
        plot_P_Q(P_Q, summary, name_tag, out_dir)
    if eig:
        print("Eigenvalue spectra:")
        plot_eigenvalues(eig, name_tag, out_dir)
    if blk:
        print("Block means:")
        plot_block_means(blk, meta, name_tag, out_dir)
    if Q:
        print("Correlation vs. layer distance:")
        plot_correlation_vs_layer_distance(P_Q, head_keys, Q, name_tag, out_dir)

    if "frob_cosine" not in Q:
        return

    frob = Q["frob_cosine"]
    print("MP overlay:")
    _, eigen_stats = plot_mp_overlay(frob, meta, name_tag, out_dir)

    # Persist the eigenvalue statistics next to the figures.
    stats_path = os.path.join(out_dir, f"{name_tag}_eigen_stats.json")
    serialisable = {k: float(v) for k, v in eigen_stats.items()}
    with open(stats_path, "w") as f:
        json.dump(serialisable, f, indent=2)
    print(f" {stats_path}")

    print("Dominant eigenvectors:")
    plot_dominant_eigenvector(frob, meta, name_tag, out_dir)
def main():
    """Command-line entry point: parse the options and generate the figures.

    With ``--weight-type`` only that component is plotted and any error
    propagates.  Otherwise every weight type and cross-correlation found in
    the data directory is plotted; a failure in one of them is printed and
    the remaining ones are still processed.
    """
    parser = argparse.ArgumentParser(description="Plot head-head correlations from saved data")
    parser.add_argument("--data", type=str, default="corr_out",
                        help="Directory with saved correlation outputs")
    parser.add_argument("--model", type=str, default="gpt2")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--weight-type", type=str, default=None,
                        help="Specific weight type (default: auto-discover all)")
    parser.add_argument("--out", type=str, default=None,
                        help="Output figure directory (default: {data}/figures)")
    parser.add_argument("--metrics", nargs="+", default=None,
                        help="Subset of metrics to plot (default: all)")
    args = parser.parse_args()

    out_dir = args.out or os.path.join(args.data, "figures")
    os.makedirs(out_dir, exist_ok=True)
    rule = "=" * 50

    if args.weight_type:
        # Single weight type (backward compat)
        plot_weight_type(args.data, args.model, args.revision,
                         args.weight_type, out_dir, args.metrics)
        print(f"\nDone. Figures in {out_dir}")
        return

    # Auto-discover all weight types and cross-correlations
    weight_types, cross_labels = discover_weight_types(
        args.data, args.model, args.revision)
    if not (weight_types or cross_labels):
        # Fallback: try W_QK (old naming)
        weight_types = ["W_QK"]
    print(f"Found weight types: {weight_types}")
    if cross_labels:
        print(f"Found cross-correlations: {cross_labels}")

    # Self-correlations
    for wt in weight_types:
        try:
            plot_weight_type(args.data, args.model, args.revision,
                             wt, out_dir, args.metrics)
        except Exception as e:
            print(f" *** Error plotting {wt}: {e}")

    # Cross-correlations
    for label in cross_labels:
        try:
            print(f"\n{rule}")
            print(f"Cross-correlation: {label}")
            print(rule)
            cross = load_cross_results(args.data, args.model,
                                       label, args.revision)
            name_tag = f"{args.model}"
            matrices = cross["Q"]
            for metric, Q_mat in matrices.items():
                plot_cross_heatmap(Q_mat, cross["keys"], metric, label,
                                   name_tag, out_dir)
            if matrices:
                plot_cross_diagonal(matrices, cross["keys"], label,
                                    name_tag, out_dir)
        except Exception as e:
            print(f" *** Error plotting {label}: {e}")

    print(f"\nDone. Figures in {out_dir}")


if __name__ == "__main__":
    main()