| |
|
| | """
|
| | Spectral analysis of Prisma / Circuit Transformer checkpoints.
|
| |
|
| | Computes SVD spectra of weight matrices and (optionally) activation covariances,
|
| | revealing how the model organizes information geometrically.
|
| |
|
| | Analyses:
|
| | 1. Weight spectra — singular value distributions per matrix
|
| | 2. Effective rank — how many dimensions carry real signal
|
| | 3. Power-law fit — Martin & Mahoney alpha exponent (training quality)
|
| | 4. MP bound — Marchenko-Pastur separation of signal vs noise
|
| | 5. Mirror comparison — expand vs compress activation spectra (Prisma-specific)
|
| | 6. Embedding alignment — spectral similarity between embed and final hidden states
|
| | 7. Layer-wise summary — effective rank progression through the network (the lens)
|
| |
|
| | Usage:
|
| | # Weight-only analysis (no data needed)
|
| | python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt
|
| |
|
| | # Full analysis with activation spectra (needs data)
|
| | python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt \
|
| | --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train --num-samples 512
|
| |
|
| | # Compare two checkpoints
|
| | python -m circuits.scripts.spectral_analysis \
|
| | --checkpoint path/to/prisma.pt --checkpoint-b path/to/standard.pt
|
| |
|
| | # Compare against HuggingFace model
|
| | python -m circuits.scripts.spectral_analysis \
|
| | --checkpoint path/to/prisma.pt --hf-model gpt2-medium
|
| | """
|
| |
|
| | import argparse
|
| | import json
|
| | import sys
|
| | import os
|
| | from pathlib import Path
|
| | from collections import defaultdict
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | import matplotlib
|
| | matplotlib.use("Agg")
|
| | import matplotlib.pyplot as plt
|
| | from matplotlib.gridspec import GridSpec
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
    """Rebuild a Prisma/Circuit model from a saved checkpoint.

    The checkpoint is deserialized onto the CPU, the architecture is selected
    from its ``model_type`` entry ("standard" when the key is absent), the
    stored weights are loaded non-strictly, and the model is moved to
    ``device`` and switched to eval mode.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the constructed model is moved to.

    Returns:
        Tuple ``(model, config_dict, model_type)``. ``config_dict`` is the
        checkpoint's ``"config"`` entry (``{}`` when missing).
    """
    # Put parents[2] of this file's resolved path on sys.path so the
    # ``circuits`` package imports below resolve.
    package_root = Path(__file__).resolve().parents[2]
    sys.path.insert(0, str(package_root))
    from circuits.config import CircuitConfig
    from circuits.model import CircuitTransformer
    from circuits.mirrored import MirroredConfig, MirroredTransformer

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model_type = checkpoint.get("model_type", "standard")
    config_dict = checkpoint.get("config", {})

    if model_type != "mirrored":
        config = CircuitConfig.from_dict(config_dict)
        model = CircuitTransformer(config)
    else:
        # A truthy "dual_gate_middle" entry is removed before the config is built.
        if config_dict.get("dual_gate_middle"):
            del config_dict["dual_gate_middle"]
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config)

    weights = checkpoint["model"]
    # Strip the "_orig_mod." key prefix (presumably left by torch.compile) so
    # the keys match the freshly built module.
    wrapper_prefix = "_orig_mod."
    if any(key.startswith(wrapper_prefix) for key in weights):
        weights = {key.removeprefix(wrapper_prefix): tensor for key, tensor in weights.items()}
    model.load_state_dict(weights, strict=False)
    model.to(device).eval()

    return model, config_dict, model_type
|
| |
|
| |
|
def load_hf_model(model_name: str, device: str = "cpu"):
    """Fetch a HuggingFace causal language model in float32, set up for inference.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        device: Device the model is moved to.

    Returns:
        The loaded model, on ``device`` and in eval mode.
    """
    # Imported lazily so ``transformers`` is only required when this path runs.
    from transformers import AutoModelForCausalLM

    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    )
    hf_model.to(device).eval()
    return hf_model
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def compute_singular_values(weight: torch.Tensor) -> "np.ndarray | None":
    """Return the singular values of a 2D weight matrix.

    Args:
        weight: Tensor to decompose. Only rank-2 tensors are supported.

    Returns:
        A NumPy array of singular values as produced by
        ``torch.linalg.svdvals`` on a float32 CPU copy of ``weight``, or
        ``None`` when ``weight`` is not 2-dimensional.
    """
    # Reject unsupported ranks before paying for the float/CPU copy.
    if weight.ndim != 2:
        return None
    matrix = weight.detach().float().cpu()
    return torch.linalg.svdvals(matrix).numpy()
|
| |
|
| |
|
def effective_rank(sv: np.ndarray) -> float:
    """Entropy-based effective rank (Roy & Vetterli, 2007).

    Singular values above 1e-10 are normalised into a distribution
    ``p_i = sigma_i / sum(sigma)`` and the result is ``exp(H(p))`` with ``H``
    the Shannon entropy (natural log).

    Returns:
        The effective rank as a float, or 0.0 when no singular value exceeds
        the 1e-10 threshold.
    """
    significant = sv[sv > 1e-10]
    if significant.size == 0:
        return 0.0
    weights = significant / significant.sum()
    shannon = -np.sum(weights * np.log(weights))
    return float(np.exp(shannon))
|
| |
|
| |
|
def stable_rank(sv: np.ndarray) -> float:
    """Stable rank: sum(sigma^2) divided by the square of the leading value.

    ``sv[0]`` is used as the largest singular value, so the input is expected
    in descending order.

    Returns:
        The ratio as a float, or 0.0 for an empty input or when ``sv[0]`` is
        below 1e-10.
    """
    if len(sv) == 0:
        return 0.0
    leading = sv[0]
    if leading < 1e-10:
        return 0.0
    squared_total = (sv ** 2).sum()
    return float(squared_total / (leading ** 2))
|
| |
|
| |
|
def marchenko_pastur_bound(m: int, n: int, sv: np.ndarray) -> float:
    """Estimate a Marchenko-Pastur-style noise threshold for singular values.

    The noise scale is estimated from the lower half of ``sv`` (taken by
    position, so ``sv`` is expected in descending order) as
    ``median / sqrt(max(m, n))``. The returned threshold is
    ``sigma_est * (1 + sqrt(gamma))**2 * sqrt(min(m, n))`` with
    ``gamma = max(m, n) / min(m, n)``.

    NOTE(review): the classical largest-singular-value edge for an i.i.d.
    matrix is ``sigma * (sqrt(m) + sqrt(n))``; this expression additionally
    squares the ``(1 + sqrt(gamma))`` factor. It is kept unchanged so reported
    numbers stay comparable -- confirm whether that is intended.

    Args:
        m: Number of rows of the weight matrix.
        n: Number of columns of the weight matrix.
        sv: Singular values of the matrix.

    Returns:
        The threshold as a plain Python float; 0.0 when ``sv`` is empty or a
        dimension is not positive.
    """
    # Degenerate inputs: previously a zero-sized dimension raised
    # ZeroDivisionError while computing the aspect ratio.
    if len(sv) == 0 or min(m, n) <= 0:
        return 0.0

    long_side = max(m, n)
    short_side = min(m, n)
    gamma = long_side / short_side

    # For a non-empty spectrum this slice always holds at least one value.
    bulk = sv[len(sv) // 2:]
    sigma_est = float(np.median(bulk)) / np.sqrt(long_side)
    mp_upper = sigma_est * (1.0 + np.sqrt(gamma)) ** 2 * np.sqrt(short_side)
    return float(mp_upper)
|
| |
|
| |
|
def fit_power_law(sv: np.ndarray, fit_fraction: float = 0.8) -> tuple[float, float]:
    """Fit a power law to the leading part of a singular value spectrum.

    Values at or below 1e-10 are discarded. A straight line is then fitted by
    least squares to ``log(sigma)`` against ``log(rank)`` over the first
    ``max(10, int(count * fit_fraction))`` values, capped at the number of
    values available. ``sv`` is expected in descending order, so these are
    the largest values.

    Args:
        sv: Singular values.
        fit_fraction: Fraction of the retained values used for the fit.
            Values above 1.0 behave like 1.0.

    Returns:
        ``(alpha, r_squared)`` where ``alpha`` is the negated slope of the
        fitted line and ``r_squared`` its coefficient of determination in
        log-log space. ``(0.0, 0.0)`` when fewer than 10 values remain.
        alpha < 2 is read as heavy-tailed (well-trained).
    """
    sv = sv[sv > 1e-10]
    count = len(sv)
    if count < 10:
        return 0.0, 0.0

    # Cap at the number of available values: an oversized fit_fraction used to
    # make the rank and value arrays differ in length and crash np.polyfit.
    n_fit = min(count, max(10, int(count * fit_fraction)))
    sv_fit = sv[:n_fit]

    log_rank = np.log(np.arange(1, n_fit + 1))
    log_sv = np.log(sv_fit)

    coeffs = np.polyfit(log_rank, log_sv, 1)
    alpha = -coeffs[0]

    predicted = np.polyval(coeffs, log_rank)
    ss_res = ((log_sv - predicted) ** 2).sum()
    ss_tot = ((log_sv - log_sv.mean()) ** 2).sum()
    # A constant spectrum has no variance to explain; report 0 rather than divide by 0.
    r_sq = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

    return float(alpha), float(r_sq)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def analyze_weight_spectra(model: nn.Module, model_label: str = "model") -> dict:
    """Collect spectral statistics for every 2D parameter of ``model``.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        model_label: Accepted for interface symmetry; not used here.

    Returns:
        Dict mapping parameter name to a dict holding the matrix shape, the
        singular values and the statistics derived from them (ranks, norms,
        MP bound counts, power-law fit, condition number).
    """
    spectra = {}
    for param_name, tensor in model.named_parameters():
        if tensor.ndim != 2:
            continue
        sigma = compute_singular_values(tensor)
        if sigma is None:
            continue

        rows, cols = tensor.shape
        count = len(sigma)
        noise_edge = marchenko_pastur_bound(rows, cols, sigma)
        above_edge = int((sigma > noise_edge).sum())
        alpha, r_sq = fit_power_law(sigma)

        # The condition number is only meaningful when the smallest value is
        # not numerically zero.
        if sigma[-1] > 1e-10:
            condition = float(sigma[0] / sigma[-1])
        else:
            condition = float("inf")

        spectra[param_name] = {
            "shape": (rows, cols),
            "singular_values": sigma,
            "effective_rank": effective_rank(sigma),
            "stable_rank": stable_rank(sigma),
            "spectral_norm": float(sigma[0]),
            "frobenius_norm": float(np.sqrt((sigma ** 2).sum())),
            "mp_bound": noise_edge,
            "n_above_mp": above_edge,
            "n_total": count,
            "signal_ratio": above_edge / count if count > 0 else 0,
            "alpha": alpha,
            "alpha_r2": r_sq,
            "condition_number": condition,
        }
    return spectra
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def collect_activations(model, input_ids: torch.Tensor,
                        word_positions: torch.Tensor = None,
                        model_type: str = "standard") -> dict[str, torch.Tensor]:
    """Run one forward pass and capture block outputs through forward hooks.

    Hooks are attached to ``model.mirror_blocks`` and ``model.middle_blocks``
    when ``model_type == "mirrored"``, otherwise to ``model.layers``, plus
    ``model.embed`` in both cases. All hooks are removed again before
    returning, including when the forward pass raises.

    NOTE(review): each hook stores its output under a fixed key, so if the
    model's forward invokes the same block more than once only the last call
    is kept.

    Args:
        model: Model exposing the block containers named above and ``embed``.
        input_ids: Batch passed positionally to ``model``.
        word_positions: Optional tensor forwarded as the ``word_positions``
            keyword argument when not ``None``.
        model_type: ``"mirrored"`` selects the mirrored hook layout.

    Returns:
        Dict mapping hook name (``embedding``, ``layer_i``, ``expand_i``,
        ``middle_i``) to the captured output as a detached float32 CPU tensor.
    """
    activations = {}
    hooks = []

    def make_hook(name):
        def hook_fn(module, inputs, output):
            # Blocks may return a tuple; the first element is captured.
            out = output[0] if isinstance(output, tuple) else output
            activations[name] = out.detach().float().cpu()
        return hook_fn

    try:
        if model_type == "mirrored":
            for i, block in enumerate(model.mirror_blocks):
                hooks.append(block.register_forward_hook(make_hook(f"expand_{i}")))
            for i, block in enumerate(model.middle_blocks):
                hooks.append(block.register_forward_hook(make_hook(f"middle_{i}")))
        else:
            for i, block in enumerate(model.layers):
                hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))

        hooks.append(model.embed.register_forward_hook(make_hook("embedding")))

        kwargs = {}
        if word_positions is not None:
            kwargs["word_positions"] = word_positions
        with torch.no_grad():
            model(input_ids, **kwargs)
    finally:
        # Always detach the hooks so a failed forward pass does not leave
        # them attached to the caller's model.
        for handle in hooks:
            handle.remove()

    return activations
|
| |
|
| |
|
def collect_mirrored_activations(model, input_ids: torch.Tensor,
                                 word_positions: torch.Tensor = None) -> dict[str, torch.Tensor]:
    """Step through a MirroredTransformer by hand, recording every phase.

    The forward pass is replayed manually under ``torch.no_grad()`` so the
    second (reverse-order) pass over ``model.mirror_blocks`` can be stored
    under its own ``compress_*`` keys.

    Args:
        model: Model exposing ``embed``, ``embed_proj``, ``embed_scale``,
            ``mirror_blocks``, ``middle_blocks`` and ``norm`` (and
            ``embed_g3``/``embed_g4`` when ``embed_proj`` is set).
        input_ids: Batch fed to ``model.embed``.
        word_positions: Forwarded to every block as a keyword argument.

    Returns:
        Dict with keys ``embedding``, ``expand_i``, ``middle_i``,
        ``compress_j`` and ``final_norm``; each value is a detached float32
        CPU tensor. ``compress_0`` comes from the last mirror block.
    """
    captured = {}

    def snapshot(key, tensor):
        captured[key] = tensor.detach().float().cpu()

    with torch.no_grad():
        hidden = model.embed(input_ids)
        if model.embed_proj is not None:
            import torch.nn.functional as F
            if model.embed_g3 is None:
                hidden = F.silu(model.embed_proj(hidden))
            else:
                # Gated projection: g4 modulates g3, which gates the projection.
                gate_four = F.silu(model.embed_g4(hidden))
                gate_three = F.silu(model.embed_g3(hidden) * gate_four)
                hidden = model.embed_proj(hidden) * gate_three
        hidden = hidden * model.embed_scale
        snapshot("embedding", hidden)

        # First pass over the mirror blocks, in order.
        for position, block in enumerate(model.mirror_blocks):
            hidden, _ = block(hidden, word_positions=word_positions)
            snapshot(f"expand_{position}", hidden)

        for position, block in enumerate(model.middle_blocks):
            hidden, _ = block(hidden, word_positions=word_positions)
            snapshot(f"middle_{position}", hidden)

        # Second pass over the same mirror blocks, last block first.
        mirror_count = len(model.mirror_blocks)
        for step, block_index in enumerate(range(mirror_count - 1, -1, -1)):
            hidden, _ = model.mirror_blocks[block_index](hidden, word_positions=word_positions)
            snapshot(f"compress_{step}", hidden)

        hidden = model.norm(hidden)
        snapshot("final_norm", hidden)

    return captured
|
| |
|
| |
|
def activation_spectrum(act: torch.Tensor, max_components: int = 256) -> "dict | None":
    """Compute the leading eigenvalues of an activation covariance matrix.

    ``act`` is flattened to ``[-1, act.shape[-1]]``, mean-centred per feature,
    and decomposed with ``torch.pca_lowrank`` (falling back to an exact
    eigendecomposition of the covariance if that fails). Eigenvalues at or
    below 1e-10 are dropped.

    Args:
        act: Activation tensor whose last axis is the feature axis.
        max_components: Upper limit on the number of eigenvalues kept.

    Returns:
        ``None`` when fewer than two rows are available, otherwise a dict with
        ``eigenvalues`` (NumPy array), ``effective_rank``, ``total_variance``
        (sum of the retained eigenvalues), ``top1_variance_ratio``,
        ``top10_variance_ratio`` and ``n_components``. The ratios are relative
        to the retained eigenvalues and are 0 when none survive.
    """
    flat = act.reshape(-1, act.shape[-1])
    N, D = flat.shape

    # A covariance estimate needs at least two observations.
    if N < 2:
        return None

    flat = flat - flat.mean(dim=0, keepdim=True)

    n_components = min(max_components, D, N)
    try:
        _, S, _ = torch.pca_lowrank(flat, q=n_components)
        eigenvalues = (S ** 2 / (N - 1)).numpy()
    except Exception as exc:
        # Best-effort fallback, but make the switch visible instead of silent.
        import warnings
        warnings.warn(
            f"torch.pca_lowrank failed ({exc!r}); using exact covariance eigendecomposition",
            RuntimeWarning,
            stacklevel=2,
        )
        cov = (flat.T @ flat) / (N - 1)
        eigenvalues = torch.linalg.eigvalsh(cov).flip(0).numpy()
        eigenvalues = eigenvalues[:max_components]

    eigenvalues = eigenvalues[eigenvalues > 1e-10]
    has_components = len(eigenvalues) > 0
    total = eigenvalues.sum()

    return {
        "eigenvalues": eigenvalues,
        "effective_rank": effective_rank(np.sqrt(np.maximum(eigenvalues, 0))),
        "total_variance": float(total),
        "top1_variance_ratio": float(eigenvalues[0] / total) if has_components else 0,
        # With fewer than 10 components the slice covers all of them, so the
        # ratio is 1.0 rather than the 0 that was reported before.
        "top10_variance_ratio": float(eigenvalues[:10].sum() / total) if has_components else 0,
        "n_components": len(eigenvalues),
    }
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def plot_weight_spectra(results: dict, output_dir: Path, model_label: str = "model",
                        results_b: dict = None, model_b_label: str = "model_b"):
    """Write one singular-value figure per weight-matrix family.

    Parameters are bucketed by substrings of their names (attention Q/K/V/O,
    FFN W1-W4, embedding, lm_head, other). Each bucket is saved as
    ``weight_spectra_<group>.png`` in ``output_dir`` with a linear panel
    (including a dotted line at each matrix's MP bound) and a log-log panel.

    ``results_b`` and ``model_b_label`` are accepted but not used.
    """
    def family_of(param_name):
        # Checks run in a fixed order; the first match wins.
        in_attention = "attn" in param_name
        if in_attention and ("q_proj" in param_name or "wq" in param_name):
            return "attention_Q"
        if in_attention and ("k_proj" in param_name or "wk" in param_name):
            return "attention_K"
        if in_attention and ("v_proj" in param_name or "wv" in param_name):
            return "attention_V"
        if in_attention and ("o_proj" in param_name or "wo" in param_name):
            return "attention_O"
        if "w1" in param_name or "up_proj" in param_name:
            return "ffn_W1"
        if "w2" in param_name or "down_proj" in param_name:
            return "ffn_W2"
        if "w3" in param_name or "gate_proj" in param_name:
            return "ffn_gate_W3"
        if "w4" in param_name:
            return "ffn_gate_W4"
        if "embed" in param_name or "wte" in param_name:
            return "embedding"
        if "lm_head" in param_name:
            return "lm_head"
        return "other"

    groups = defaultdict(list)
    for name, data in results.items():
        groups[family_of(name)].append((name, data))

    for group_name, items in groups.items():
        if not items:
            continue

        fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(f"{model_label} β {group_name} weight spectra", fontsize=13)

        palette = plt.cm.viridis(np.linspace(0.1, 0.9, len(items)))
        for color, (name, data) in zip(palette, items):
            sv = data["singular_values"]
            if "." in name:
                pieces = name.split(".")
                short_name = pieces[-2] + "." + pieces[-1]
            else:
                short_name = name
            ranks = np.arange(1, len(sv) + 1)
            ax_linear.plot(sv, color=color, alpha=0.7, linewidth=0.8, label=short_name)
            ax_log.loglog(ranks, sv, color=color, alpha=0.7,
                          linewidth=0.8, label=short_name)

            ax_linear.axhline(data["mp_bound"], color=color, linestyle=":", alpha=0.3)

        for axis, title in ((ax_linear, "Linear scale"), (ax_log, "Log-log scale (power law)")):
            axis.set_xlabel("Rank")
            axis.set_ylabel("Singular value")
            axis.set_title(title)
            axis.legend(fontsize=6, ncol=2)

        plt.tight_layout()
        fig.savefig(output_dir / f"weight_spectra_{group_name}.png", dpi=150)
        plt.close(fig)
|
| |
|
| |
|
def plot_effective_rank_progression(results: dict, output_dir: Path,
                                    model_label: str = "model",
                                    results_b: dict = None,
                                    model_b_label: str = "model_b"):
    """Plot layer-wise spectral statistics of the FFN W1 / up_proj matrices.

    Entries of ``results`` whose name contains ``"w1"`` or ``"up_proj"`` are
    drawn as four bar charts (effective rank, stable rank, power-law alpha,
    signal ratio) and saved as ``layer_progression.png`` in ``output_dir``.
    Nothing is written when no entry matches.

    Layers are ordered with numeric name components compared as integers, so
    layer 10 follows layer 9 instead of sorting between 1 and 2.

    ``results_b`` and ``model_b_label`` are accepted but not used.
    """
    def natural_key(name):
        # Numeric components sort by value and ahead of textual ones.
        return [(0, int(tok), "") if tok.isdecimal() else (1, 0, tok)
                for tok in name.split(".")]

    layer_rows = []
    for name in sorted(results, key=natural_key):
        if "w1" not in name and "up_proj" not in name:
            continue
        data = results[name]
        # Tick label: first all-digit name component, else the full name.
        layer_label = next((tok for tok in name.split(".") if tok.isdigit()), name)
        layer_rows.append({
            "label": layer_label,
            "effective_rank": data["effective_rank"],
            "stable_rank": data["stable_rank"],
            "alpha": data["alpha"],
            "signal_ratio": data["signal_ratio"],
        })

    if not layer_rows:
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f"{model_label} β Layer-wise spectral properties (FFN W1)", fontsize=13)

    positions = range(len(layer_rows))
    tick_labels = [row["label"] for row in layer_rows]

    panels = (
        (axes[0, 0], "effective_rank", "steelblue", "Effective rank",
         "Effective rank (entropy-based)"),
        (axes[0, 1], "stable_rank", "coral", "Stable rank",
         "Stable rank (Frobenius/spectral)"),
        (axes[1, 0], "alpha", "mediumpurple", "Alpha",
         "Power-law exponent (lower = heavier tail = more structure)"),
        (axes[1, 1], "signal_ratio", "seagreen", "Signal ratio",
         "Fraction of singular values above MP bound"),
    )
    for ax, key, color, ylabel, title in panels:
        ax.bar(positions, [row[key] for row in layer_rows], color=color, alpha=0.8)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, fontsize=7)

    # Reference line for the alpha = 2 boundary on the power-law panel.
    axes[1, 0].axhline(2.0, color="red", linestyle="--", alpha=0.5, label="alpha=2 boundary")
    axes[1, 0].legend(fontsize=8)

    plt.tight_layout()
    fig.savefig(output_dir / "layer_progression.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_activation_spectra(act_spectra: dict, output_dir: Path,
                            model_label: str = "model"):
    """Draw activation eigenspectra and the effective-rank progression.

    Two figures are written to ``output_dir``: ``activation_spectra.png``
    (normalised eigenvalues on a log axis plus cumulative variance) and
    ``activation_rank_progression.png`` (one effective-rank bar per layer,
    coloured by phase). Returns immediately when ``act_spectra`` is empty.
    """
    if not act_spectra:
        return

    pinned = {"embedding": -1, "final_norm": 999}
    phase_base = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}

    def position(layer_name):
        # Embedding first, final norm last, the rest by phase then index.
        if layer_name in pinned:
            return pinned[layer_name]
        tokens = layer_name.split("_")
        number = 0
        if len(tokens) > 1 and tokens[1].isdigit():
            number = int(tokens[1])
        return phase_base.get(tokens[0], 300) + number

    sorted_names = sorted(act_spectra.keys(), key=position)

    # Figure 1: per-layer eigenvalue curves.
    fig, (ax_dist, ax_cum) = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"{model_label} β Activation eigenspectra", fontsize=13)

    palette = plt.cm.coolwarm(np.linspace(0, 1, len(sorted_names)))
    for color, name in zip(palette, sorted_names):
        ev = act_spectra[name]["eigenvalues"]
        ax_dist.semilogy(ev / ev.sum(), color=color, alpha=0.7, linewidth=1.0, label=name)
        ax_cum.plot(np.cumsum(ev) / ev.sum(), color=color, alpha=0.7, linewidth=1.0, label=name)

    ax_dist.set_xlabel("Component")
    ax_dist.set_ylabel("Normalized eigenvalue (log)")
    ax_dist.set_title("Eigenvalue distribution")
    ax_dist.legend(fontsize=6, ncol=2)

    ax_cum.set_xlabel("Component")
    ax_cum.set_ylabel("Cumulative variance explained")
    ax_cum.set_title("Variance concentration")
    ax_cum.axhline(0.9, color="gray", linestyle="--", alpha=0.4, label="90%")
    ax_cum.legend(fontsize=6, ncol=2)

    plt.tight_layout()
    fig.savefig(output_dir / "activation_spectra.png", dpi=150)
    plt.close(fig)

    # Figure 2: effective rank per layer.
    fig, ax = plt.subplots(figsize=(12, 5))
    fig.suptitle(f"{model_label} β Activation effective rank progression", fontsize=13)

    eranks = [act_spectra[name]["effective_rank"] for name in sorted_names]
    phase_colors = (("expand", "steelblue"), ("middle", "goldenrod"), ("compress", "coral"))
    colors = [
        next((shade for phase, shade in phase_colors if phase in name), "gray")
        for name in sorted_names
    ]

    slots = range(len(sorted_names))
    ax.bar(slots, eranks, color=colors, alpha=0.8)
    ax.set_xticks(slots)
    ax.set_xticklabels(sorted_names, rotation=45, ha="right", fontsize=8)
    ax.set_ylabel("Effective rank")
    ax.set_title("Expand (blue) β Middle (gold) β Compress (coral)")

    plt.tight_layout()
    fig.savefig(output_dir / "activation_rank_progression.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_mirror_comparison(act_spectra: dict, output_dir: Path,
                           model_label: str = "model"):
    """Overlay expand and compress activation spectra, one subplot per index.

    The i-th ``expand_*`` entry is drawn against the i-th ``compress_*`` entry
    (normalised eigenvalues, at most 100 components, log axis) and the figure
    is saved as ``mirror_pair_comparison.png`` in ``output_dir``. Nothing is
    written unless both phases are present.

    Entries are ordered by their numeric suffix, so ``expand_10`` follows
    ``expand_9`` instead of sorting between ``expand_1`` and ``expand_2``.
    """
    def by_index(name):
        suffix = name.split("_", 1)[1]
        return (0, int(suffix), "") if suffix.isdecimal() else (1, 0, suffix)

    expand_layers = sorted((n for n in act_spectra if n.startswith("expand_")), key=by_index)
    compress_layers = sorted((n for n in act_spectra if n.startswith("compress_")), key=by_index)

    if not expand_layers or not compress_layers:
        return

    n_pairs = min(len(expand_layers), len(compress_layers))
    # squeeze=False keeps ``axes`` two-dimensional even for a single pair.
    fig, axes = plt.subplots(1, n_pairs, figsize=(4 * n_pairs, 4), squeeze=False)
    fig.suptitle(f"{model_label} β Mirror pair activation spectra (expand vs compress)", fontsize=13)

    # zip stops at the shorter list, i.e. after n_pairs pairs.
    for i, (expand_name, compress_name) in enumerate(zip(expand_layers, compress_layers)):
        ax = axes[0, i]
        expand_stats = act_spectra[expand_name]
        compress_stats = act_spectra[compress_name]
        exp_ev = expand_stats["eigenvalues"]
        comp_ev = compress_stats["eigenvalues"]

        n_plot = min(len(exp_ev), len(comp_ev), 100)
        ax.semilogy(exp_ev[:n_plot] / exp_ev.sum(), color="steelblue", alpha=0.8,
                    linewidth=1.5, label="expand")
        ax.semilogy(comp_ev[:n_plot] / comp_ev.sum(), color="coral", alpha=0.8,
                    linewidth=1.5, label="compress")

        exp_er = expand_stats["effective_rank"]
        comp_er = compress_stats["effective_rank"]
        ax.set_title(f"Pair {i}\nerank: {exp_er:.0f} / {comp_er:.0f}", fontsize=10)
        ax.set_xlabel("Component")
        if i == 0:
            ax.set_ylabel("Normalized eigenvalue")
        ax.legend(fontsize=8)

    plt.tight_layout()
    fig.savefig(output_dir / "mirror_pair_comparison.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_gate_spectra(results: dict, output_dir: Path, model_label: str = "model"):
    """Compare W3 and W4 gate weight spectra (G2LU gates).

    Entries whose name contains ``"ffn"`` and ``"w3"`` (respectively ``"w4"``)
    are paired by position. The figure, saved as ``gate_spectra.png`` in
    ``output_dir``, shows the singular values on a log axis and the effective
    rank of both gates per pair. Nothing is written unless both gate kinds
    are present.

    Names are ordered with numeric components compared as integers, so block
    10 follows block 9 instead of sorting between 1 and 2.
    """
    def natural_key(name):
        # Numeric components sort by value and ahead of textual ones.
        return [(0, int(tok), "") if tok.isdecimal() else (1, 0, tok)
                for tok in name.split(".")]

    ordered_names = sorted(results, key=natural_key)
    w3_items = [(n, results[n]) for n in ordered_names if "w3" in n and "ffn" in n]
    w4_items = [(n, results[n]) for n in ordered_names if "w4" in n and "ffn" in n]

    if not w3_items or not w4_items:
        return

    n_pairs = min(len(w3_items), len(w4_items))
    # zip stops at the shorter list, i.e. after n_pairs pairs.
    pairs = [(w3[1], w4[1]) for w3, w4 in zip(w3_items, w4_items)]

    fig, (ax_spectra, ax_rank) = plt.subplots(2, 1, figsize=(12, 8))
    fig.suptitle(f"{model_label} β G2LU gate spectra (W3 outer vs W4 inner)", fontsize=13)

    cmap_w3 = plt.cm.Blues(np.linspace(0.3, 0.9, n_pairs))
    cmap_w4 = plt.cm.Reds(np.linspace(0.3, 0.9, n_pairs))

    for i, (w3_data, w4_data) in enumerate(pairs):
        ax_spectra.semilogy(w3_data["singular_values"], color=cmap_w3[i], alpha=0.6,
                            linewidth=0.8, label=f"W3 pair {i}")
        ax_spectra.semilogy(w4_data["singular_values"], color=cmap_w4[i], alpha=0.6,
                            linewidth=0.8, label=f"W4 pair {i}")

    ax_spectra.set_xlabel("Rank")
    ax_spectra.set_ylabel("Singular value (log)")
    ax_spectra.set_title("Gate weight spectra")
    ax_spectra.legend(fontsize=6, ncol=4)

    er_w3 = [w3_data["effective_rank"] for w3_data, _ in pairs]
    er_w4 = [w4_data["effective_rank"] for _, w4_data in pairs]
    x = np.arange(n_pairs)
    ax_rank.bar(x - 0.15, er_w3, 0.3, color="steelblue", alpha=0.8, label="W3 (outer gate)")
    ax_rank.bar(x + 0.15, er_w4, 0.3, color="coral", alpha=0.8, label="W4 (inner gate)")
    ax_rank.set_xlabel("Mirror pair")
    ax_rank.set_ylabel("Effective rank")
    ax_rank.set_title("Gate effective rank by pair")
    ax_rank.set_xticks(x)
    ax_rank.legend()

    plt.tight_layout()
    fig.savefig(output_dir / "gate_spectra.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_embedding_alignment(results: dict, act_spectra: dict, output_dir: Path,
                             model_label: str = "model"):
    """Plot the embedding weight spectrum against the final activation spectrum.

    The embedding entry is the first key of ``results`` containing "embed"
    (case-insensitive) but none of "proj", "g3", "g4". The activation entry is
    ``final_norm`` when present, otherwise ``compress_0``. Both spectra are
    normalised by their own sum and at most 200 components are drawn. The
    figure is saved as ``embedding_alignment.png`` in ``output_dir``; nothing
    is written when either entry is missing.
    """
    skip_tags = ("proj", "g3", "g4")
    embed_data = None
    for name, data in results.items():
        lowered = name.lower()
        if "embed" in lowered and not any(tag in lowered for tag in skip_tags):
            embed_data = data
            break

    final_act = act_spectra.get("final_norm") or act_spectra.get("compress_0")
    if embed_data is None or final_act is None:
        return

    fig, (ax_shape, ax_cum) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle(f"{model_label} β Embedding vs final activation spectra", fontsize=13)

    sv_embed = embed_data["singular_values"]
    ev_final = final_act["eigenvalues"]
    embed_share = sv_embed / sv_embed.sum()
    final_share = ev_final / ev_final.sum()

    limit = min(len(embed_share), len(final_share), 200)
    embed_head = embed_share[:limit]
    final_head = final_share[:limit]

    # Left panel: normalised spectra on a log axis.
    ax_shape.semilogy(embed_head, color="steelblue", linewidth=1.5,
                      label=f"Embedding (erank={embed_data['effective_rank']:.0f})")
    ax_shape.semilogy(final_head, color="coral", linewidth=1.5,
                      label=f"Final act (erank={final_act['effective_rank']:.0f})")
    ax_shape.set_xlabel("Component")
    ax_shape.set_ylabel("Normalized value (log)")
    ax_shape.set_title("Spectral shape comparison")
    ax_shape.legend()

    # Right panel: cumulative share of the same curves.
    ax_cum.plot(np.cumsum(embed_head), color="steelblue", linewidth=1.5, label="Embedding")
    ax_cum.plot(np.cumsum(final_head), color="coral", linewidth=1.5, label="Final activation")
    ax_cum.set_xlabel("Component")
    ax_cum.set_ylabel("Cumulative fraction")
    ax_cum.set_title("Variance concentration")
    ax_cum.axhline(0.9, color="gray", linestyle="--", alpha=0.4)
    ax_cum.legend()

    plt.tight_layout()
    fig.savefig(output_dir / "embedding_alignment.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_comparison(results_a: dict, results_b: dict,
                    label_a: str, label_b: str,
                    output_dir: Path):
    """Plot FFN up-projection statistics of two models side by side.

    For each model the entries whose name contains "w1", "up_proj", "c_fc" or
    "dense_h_to_4h" (and not "embed") are collected. The i-th entries of both
    models are paired and their effective rank, stable rank and power-law
    alpha are drawn as grouped bars in ``comparison.png`` inside
    ``output_dir``. Nothing is written when either model has no such entry.

    Names are ordered with numeric components compared as integers, so layer
    10 follows layer 9 instead of sorting between 1 and 2.
    """
    ffn_markers = ("w1", "up_proj", "c_fc", "dense_h_to_4h")

    def natural_key(name):
        # Numeric components sort by value and ahead of textual ones.
        return [(0, int(tok), "") if tok.isdecimal() else (1, 0, tok)
                for tok in name.split(".")]

    def extract_ffn_ranks(results):
        rows = []
        for name in sorted(results, key=natural_key):
            if "embed" in name or not any(marker in name for marker in ffn_markers):
                continue
            data = results[name]
            rows.append((name, data["effective_rank"], data["stable_rank"], data["alpha"]))
        return rows

    ranks_a = extract_ffn_ranks(results_a)
    ranks_b = extract_ffn_ranks(results_b)

    if not ranks_a or not ranks_b:
        return

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f"Comparison: {label_a} vs {label_b}", fontsize=13)

    # Only the layers both models have are compared.
    n = min(len(ranks_a), len(ranks_b))
    x = np.arange(n)

    metrics = (
        (1, "Effective rank", "Effective rank per layer"),
        (2, "Stable rank", "Stable rank per layer"),
        (3, "Alpha", "Power-law alpha per layer"),
    )
    for ax, (metric_idx, ylabel, title) in zip(axes, metrics):
        vals_a = [row[metric_idx] for row in ranks_a[:n]]
        vals_b = [row[metric_idx] for row in ranks_b[:n]]
        ax.bar(x - 0.15, vals_a, 0.3, color="steelblue", alpha=0.8, label=label_a)
        ax.bar(x + 0.15, vals_b, 0.3, color="coral", alpha=0.8, label=label_b)
        ax.set_xlabel("Layer")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(fontsize=8)

    plt.tight_layout()
    fig.savefig(output_dir / "comparison.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def print_summary(results: dict, model_label: str, act_spectra: dict = None):
    """Print a concise text summary of the spectral analysis to stdout.

    The first table lists FFN W1-W4 and embedding matrices (matched by name
    substrings) with shape, effective rank, stable rank, alpha, signal
    percentage and condition number. Condition numbers of 1e6 or more are
    shown as "inf". Mean alpha (over entries with alpha > 0) and mean
    effective rank follow, each only when there is something to average. If
    ``act_spectra`` is given, a second table lists per-layer activation
    statistics.

    Matrices are ordered with numeric name components compared as integers,
    so the ``[i]`` suffix in the table follows layer order past layer 9.
    """
    def natural_key(name):
        # Numeric components sort by value and ahead of textual ones.
        return [(0, int(tok), "") if tok.isdecimal() else (1, 0, tok)
                for tok in name.split(".")]

    banner = "=" * 70
    print(f"\n{banner}")
    print(f" Spectral Analysis: {model_label}")
    print(f"{banner}")

    components = defaultdict(list)
    for name in sorted(results, key=natural_key):
        data = results[name]
        if "w1" in name or "up_proj" in name:
            components["FFN W1 (up)"].append(data)
        elif "w2" in name or "down_proj" in name:
            components["FFN W2 (down)"].append(data)
        elif "w3" in name:
            components["FFN W3 (outer gate)"].append(data)
        elif "w4" in name:
            components["FFN W4 (inner gate)"].append(data)
        elif "embed" in name.lower() and "proj" not in name and "g3" not in name and "g4" not in name:
            components["Embedding"].append(data)

    print(f"\n{'Component':<25} {'Shape':>12} {'eRank':>8} {'sRank':>8} {'Alpha':>8} {'Sig%':>8} {'Cond#':>10}")
    print("-" * 85)
    for comp_name, items in components.items():
        for i, data in enumerate(items):
            label = f"{comp_name}" if len(items) == 1 else f"{comp_name}[{i}]"
            shape_str = f"{data['shape'][0]}x{data['shape'][1]}"
            cond = f"{data['condition_number']:.0f}" if data['condition_number'] < 1e6 else "inf"
            print(f"{label:<25} {shape_str:>12} {data['effective_rank']:>8.1f} "
                  f"{data['stable_rank']:>8.1f} {data['alpha']:>8.3f} "
                  f"{data['signal_ratio']*100:>7.1f}% {cond:>10}")

    all_alphas = [d["alpha"] for d in results.values() if d["alpha"] > 0]
    all_eranks = [d["effective_rank"] for d in results.values()]
    if all_alphas:
        print(f"\n Mean alpha: {np.mean(all_alphas):.3f} (< 2.0 = heavy-tailed = well-structured)")
    # Guarded separately so an empty result set never averages an empty list.
    if all_eranks:
        print(f" Mean effective rank: {np.mean(all_eranks):.1f}")

    if act_spectra:
        print(f"\n Activation spectra:")
        print(f" {'Layer':<25} {'eRank':>8} {'Top1%':>8} {'Top10%':>8}")
        print(" " + "-" * 55)

        order_keys = {"embedding": -1, "final_norm": 999}
        phase_offset = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}

        def sort_key(name):
            # Embedding first, final norm last, the rest by phase then index.
            if name in order_keys:
                return order_keys[name]
            parts = name.split("_")
            idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
            return phase_offset.get(parts[0], 300) + idx

        for name in sorted(act_spectra.keys(), key=sort_key):
            data = act_spectra[name]
            print(f" {name:<25} {data['effective_rank']:>8.1f} "
                  f"{data['top1_variance_ratio']*100:>7.1f}% "
                  f"{data['top10_variance_ratio']*100:>7.1f}%")
|
| |
|
| |
|
def save_results_json(results: dict, act_spectra: dict, output_path: Path):
    """Write the numeric results to ``output_path`` as indented JSON.

    Full spectra are not stored: each weight entry keeps every field except
    ``singular_values`` and gains ``top_10_sv`` with the first ten values.
    Activation entries, when ``act_spectra`` is non-empty, are stored under
    the ``_activations`` key with ``eigenvalues`` likewise replaced by
    ``top_10_ev``. Values JSON cannot encode are written via ``str``.
    """
    def compact(record, array_field, preview_field):
        # Copy every field but the array, then append a ten-item preview of it.
        slim = {}
        for field, value in record.items():
            if field != array_field:
                slim[field] = value
        slim[preview_field] = record[array_field][:10].tolist()
        return slim

    payload = {
        name: compact(data, "singular_values", "top_10_sv")
        for name, data in results.items()
    }

    if act_spectra:
        payload["_activations"] = {
            name: compact(data, "eigenvalues", "top_10_ev")
            for name, data in act_spectra.items()
        }

    with open(output_path, "w") as handle:
        json.dump(payload, handle, indent=2, default=str)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def load_sample_data(data_source: str, tokenizer_name: str, num_samples: int = 256,
                     context_length: int = 512, device: str = "cpu"):
    """Build a small batch of token ids for activation analysis.

    Args:
        data_source: Either ``hf:<dataset>[:<config>[:<split>]]`` (streamed
            through ``datasets.load_dataset``; split defaults to "train") or a
            path to a text file read one non-empty line per sample.
        tokenizer_name: Name passed to ``circuits.data.get_tokenizer``.
        num_samples: Maximum number of texts collected.
        context_length: Length every returned sequence is cut or padded to.
        device: Device the id tensor is created on.

    Returns:
        ``(input_ids, tokenizer)``. ``input_ids`` is ``None`` when no text
        yields a usable sequence.
    """
    # Put parents[2] of this file's resolved path on sys.path so the
    # ``circuits`` package import below resolves.
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.data import get_tokenizer

    tokenizer = get_tokenizer(tokenizer_name)

    if not data_source.startswith("hf:"):
        with open(data_source) as handle:
            texts = [line.strip() for line in handle if line.strip()][:num_samples]
    else:
        from datasets import load_dataset
        spec = data_source[3:].split(":")
        ds_name = spec[0]
        ds_config = spec[1] if len(spec) > 1 else None
        ds_split = spec[2] if len(spec) > 2 else "train"
        stream = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
        texts = []
        for record in stream:
            texts.append(record.get("text", ""))
            if len(texts) >= num_samples:
                break

    sequences = []
    for text in texts:
        ids = tokenizer.encode(text)
        token_count = len(ids)
        if token_count >= context_length:
            # Long enough: truncate to the context window.
            sequences.append(ids[:context_length])
        elif token_count > 32:
            # Shorter texts are right-padded with the EOS id; texts of 32
            # tokens or fewer are dropped.
            padding = [tokenizer.eos_token_id] * (context_length - token_count)
            sequences.append(ids + padding)

    if not sequences:
        return None, tokenizer

    input_ids = torch.tensor(sequences[:num_samples], device=device)
    return input_ids, tokenizer
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _checkpoint_label(checkpoint_path: str) -> str:
    """Return a short display label for a checkpoint path.

    The label is the name of the directory that contains the checkpoint.
    A bare filename such as ``model.pt`` has no named parent directory
    (``Path("model.pt").parent.name`` is ``""``), so fall back to the file
    stem instead of producing an empty label.
    """
    path = Path(checkpoint_path)
    return path.parent.name or path.stem


def main():
    """Command-line entry point: analyze one checkpoint, optionally compare.

    Parses the CLI arguments, loads the primary checkpoint, computes its
    weight spectra and (when ``--data`` is given and ``--no-activations`` is
    not) its activation spectra, optionally analyzes a second checkpoint or a
    HuggingFace model for comparison, then writes plots and JSON results into
    the output directory and prints a summary.
    """
    parser = argparse.ArgumentParser(description="Spectral analysis of Prisma checkpoints")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to Prisma/Circuit checkpoint")
    parser.add_argument("--checkpoint-b", type=str, default=None, help="Second checkpoint for comparison")
    parser.add_argument("--hf-model", type=str, default=None, help="HuggingFace model name for comparison")
    parser.add_argument("--data", type=str, default=None,
                        help="Data source for activation analysis (hf:dataset:config:split or path)")
    parser.add_argument("--num-samples", type=int, default=256, help="Number of samples for activation analysis")
    parser.add_argument("--context-length", type=int, default=512, help="Context length for activation analysis")
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: auto)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--no-activations", action="store_true", help="Skip activation analysis even if data provided")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Label for the primary checkpoint; also names the default output folder.
    label_a = _checkpoint_label(args.checkpoint)

    # Output directory: explicit --output-dir wins, otherwise derive from the label.
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("circuits/scripts/spectral_output") / label_a
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output: {output_dir}")

    # Load the primary model.
    print(f"\nLoading: {args.checkpoint}")
    model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
    print(f" Type: {model_type_a}")
    n_params = sum(p.numel() for p in model_a.parameters())
    print(f" Parameters: {n_params:,}")

    # Weight spectra need no data.
    print("\nAnalyzing weight spectra...")
    weight_results_a = analyze_weight_spectra(model_a, label_a)
    print(f" Analyzed {len(weight_results_a)} weight matrices")

    # Activation spectra are optional and require --data.
    act_spectra_a = None
    if args.data and not args.no_activations:
        # The checkpoint is read a second time only to recover the tokenizer name.
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only
        # run this on checkpoints from a trusted source.
        tokenizer_name = torch.load(args.checkpoint, map_location="cpu",
                                    weights_only=False).get("tokenizer_name", "gpt2")
        print(f"\nLoading data for activation analysis ({args.num_samples} samples)...")
        input_ids, tokenizer = load_sample_data(
            args.data, tokenizer_name, args.num_samples, args.context_length, device
        )
        if input_ids is None:
            # Previously this case was skipped without any message.
            print(" No usable samples found; skipping activation analysis")
        else:
            print(f" Data shape: {input_ids.shape}")

            # Word positions are only computed when the config enables word RoPE dims.
            word_positions = None
            word_rope_dims = config_a.get("word_rope_dims", 0)
            if word_rope_dims > 0:
                from circuits.layers import build_word_start_table, compute_word_positions
                word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
                word_positions = compute_word_positions(input_ids, word_start_table)

            print(" Collecting activations...")
            if model_type_a == "mirrored":
                raw_acts = collect_mirrored_activations(model_a, input_ids, word_positions)
            else:
                raw_acts = collect_activations(model_a, input_ids, word_positions, model_type_a)

            print(f" Computing activation spectra ({len(raw_acts)} layers)...")
            act_spectra_a = {}
            for name, act in raw_acts.items():
                spec = activation_spectrum(act)
                if spec is not None:
                    act_spectra_a[name] = spec

            # Raw activations and inputs are not used past this point.
            del raw_acts, input_ids

    # The primary model is not referenced again; drop it before a comparison
    # model is loaded so both do not have to be resident at the same time.
    del model_a

    # Optional comparison model: a second checkpoint takes precedence over --hf-model.
    weight_results_b = None
    label_b = None
    if args.checkpoint_b:
        if args.hf_model:
            print(f"\nNote: --hf-model {args.hf_model} ignored because --checkpoint-b was given")
        print(f"\nLoading comparison: {args.checkpoint_b}")
        model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
        label_b = _checkpoint_label(args.checkpoint_b)
        if label_b == label_a:
            # Two checkpoints from the same directory would otherwise share a
            # label; append the file stem so the two can be told apart.
            label_b = f"{label_b}-{Path(args.checkpoint_b).stem}"
        weight_results_b = analyze_weight_spectra(model_b, label_b)
        del model_b
    elif args.hf_model:
        print(f"\nLoading HF model: {args.hf_model}")
        model_b = load_hf_model(args.hf_model, device)
        label_b = args.hf_model
        weight_results_b = analyze_weight_spectra(model_b, label_b)
        del model_b

    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    # Plots for the primary model.
    print("\nGenerating plots...")
    plot_weight_spectra(weight_results_a, output_dir, label_a)
    plot_effective_rank_progression(weight_results_a, output_dir, label_a)
    plot_gate_spectra(weight_results_a, output_dir, label_a)

    if act_spectra_a:
        plot_activation_spectra(act_spectra_a, output_dir, label_a)
        plot_mirror_comparison(act_spectra_a, output_dir, label_a)
        plot_embedding_alignment(weight_results_a, act_spectra_a, output_dir, label_a)

    if weight_results_b and label_b:
        plot_comparison(weight_results_a, weight_results_b, label_a, label_b, output_dir)
        print_summary(weight_results_b, label_b)

    print_summary(weight_results_a, label_a, act_spectra_a)

    # Persist numeric results next to the plots.
    save_results_json(weight_results_a, act_spectra_a, output_dir / "results.json")
    if weight_results_b:
        save_results_json(weight_results_b, None, output_dir / "results_b.json")

    n_plots = sum(1 for _ in output_dir.glob("*.png"))
    print(f"\nAll outputs saved to: {output_dir}")
    print(f" Plots: {n_plots} PNG files")
    print(" Data: results.json")
    if weight_results_b:
        print(" Data: results_b.json")
|
| |
|
| |
|
# Run the CLI only when this file is executed as a script (directly or via
# `python -m`), not when it is imported as a module.
if __name__ == "__main__":
    main()
|
| |
|