Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Full production run: correlation analysis + plotting for all cached models. | |
| Runs multi-circuit analysis (QK, OV, biases, cross-correlations) for every | |
| model found in the cache directory, then generates all plots including | |
| multi-model comparisons. | |
| Usage: | |
| # Dry run — show what would be processed | |
| python scripts/run_all_correlations.py \ | |
| --cache /Volumes/Flux/Projects/transformer-analysis/downloads \ | |
| --dry-run | |
| # Full production run | |
| python scripts/run_all_correlations.py \ | |
| --cache /Volumes/Flux/Projects/transformer-analysis/downloads \ | |
| --out corr_out | |
| # Fast metrics only (skip KDE), no biases | |
| python scripts/run_all_correlations.py \ | |
| --cache /path/to/downloads --fast --no-bias | |
| # Add new metrics to an existing run without full re-extraction | |
| python scripts/run_all_correlations.py \ | |
| --cache /path/to/downloads --out corr_out \ | |
| --add-metrics hist_jensen_shannon hist_symmetric_kl | |
| """ | |
| import argparse | |
| import glob | |
| import json | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| import numpy as np | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) | |
| from transformer_analysis.model_registry import MODEL_CONFIGS | |
| from transformer_analysis.correlation_analysis import ( | |
| run_multi_circuit_analysis, | |
| find_cached_models, | |
| extract_head_store, | |
| ) | |
| from transformer_analysis.head_correlations import ( | |
| compute_correlation_matrices, | |
| correlation_summary, | |
| layer_block_means, | |
| ) | |
# Architecture table used for the runtime estimates below.
# Each value is a (n_layers, n_heads, head_dim) tuple keyed by model name.
MODEL_DIMS = {
    # GPT-2 family
    "gpt2": (12, 12, 64),
    "gpt2-medium": (24, 16, 64),
    "gpt2-large": (36, 20, 64),
    "gpt2-xl": (48, 25, 64),
    # Pythia (deduped) family
    "pythia-70m-deduped": (6, 8, 64),
    "pythia-160m-deduped": (12, 12, 64),
    "pythia-410m-deduped": (24, 16, 64),
    "pythia-1b-deduped": (16, 8, 128),
    "pythia-1.4b-deduped": (24, 16, 128),
    "pythia-2.8b-deduped": (32, 32, 80),
    "pythia-6.9b-deduped": (32, 32, 128),
    "pythia-12b-deduped": (36, 40, 128),
    # Other decoder-only models
    "llama-3.1-8b": (32, 32, 128),
    "mistral-7b-v0.3": (32, 32, 128),
}

# Metric name groups; the default run is the fast set followed by the
# histogram-divergence set.
FAST_METRICS = [
    "frob_cosine",
    "two_point",
    "connected_corr",
    "pearson_corr",
]
HIST_METRICS = [
    "hist_symmetric_kl",
    "hist_jensen_shannon",
]
DEFAULT_METRICS = [*FAST_METRICS, *HIST_METRICS]
def recompute_metrics_for_model(model_name, new_metrics, out_dir, cache_dir, device=None):
    """Add new metrics to all existing weight-type outputs for a model.

    Discovers weight types from ``{model_name}_*_metadata.json`` files in
    ``out_dir`` (files whose name contains ``_vs_`` are ignored), skips
    metrics already listed in each metadata file, re-extracts heads from the
    cache, and merges the new results into the existing ``_Q.npz``,
    ``_summary.json`` and metadata files. Per-metric eigenvalue, ``P_Q`` and
    block-mean arrays are written as ``.npy`` files next to them.

    Only metadata files named ``{model_name}_main_<weight_type>`` are
    updated; any other revision is skipped with a warning.

    Args:
        model_name: Model name used as the output filename prefix and passed
            to ``extract_head_store``.
        new_metrics: Iterable of metric names to add. Duplicates are ignored.
        out_dir: Directory holding the existing outputs to update in place.
        cache_dir: Model cache directory, passed to ``extract_head_store``.
        device: Optional device, passed through to ``extract_head_store``.

    Returns:
        None. Results are written to ``out_dir``.
    """
    meta_suffix = "_metadata.json"
    meta_paths = sorted(
        p for p in glob.glob(os.path.join(out_dir, f"{model_name}_*{meta_suffix}"))
        if "_vs_" not in os.path.basename(p)
    )
    if not meta_paths:
        logging.warning(f" No existing outputs found for {model_name} in {out_dir}")
        return

    # Drop duplicate requests (order preserved) so a metric is never computed
    # twice or appended twice to the metadata.
    requested = list(dict.fromkeys(new_metrics))
    model_rev_prefix = f"{model_name}_main_"

    for meta_path in meta_paths:
        with open(meta_path) as f:
            meta = json.load(f)

        # The glob guarantees the suffix, so strip exactly the trailing part.
        fname = os.path.basename(meta_path)[:-len(meta_suffix)]
        if not fname.startswith(model_rev_prefix):
            logging.warning(f" Skipping {os.path.basename(meta_path)}: non-main revision")
            continue
        weight_type = fname[len(model_rev_prefix):]

        already_present = meta.get("metrics", [])
        to_add = [m for m in requested if m not in already_present]
        if not to_add:
            logging.info(f" {model_name}/{weight_type}: metrics already present, skipping")
            continue
        logging.info(f" {model_name}/{weight_type}: adding {to_add}")

        store, _ = extract_head_store(
            model_name=model_name,
            weight_type=weight_type,
            revision=None,
            cache_dir=cache_dir,
            device=device,
        )
        Q_new = compute_correlation_matrices(
            store,
            metrics=tuple(to_add),
            kde_kwargs={"n_eval": 2048, "bw_method": "scott"},
            show_progress=True,
        )

        # Merge into existing .npz. The archive is read fully and closed
        # before it is rewritten, so no handle to the old file stays open.
        npz_path = os.path.join(out_dir, f"{fname}_Q.npz")
        existing = {}
        if os.path.exists(npz_path):
            with np.load(npz_path) as npz:
                existing = dict(npz)
        for m, Q in Q_new.items():
            existing[f"Q_{m}"] = Q
        np.savez_compressed(npz_path, **existing)

        # Update summary + per-metric arrays
        summary_path = os.path.join(out_dir, f"{fname}_summary.json")
        summary = {}
        if os.path.exists(summary_path):
            with open(summary_path) as f:
                summary = json.load(f)
        for m, Q in Q_new.items():
            s = correlation_summary(Q, store.keys)
            # Arrays are not JSON-serialisable; they are saved as .npy below.
            summary[m] = {k: v for k, v in s.items() if not isinstance(v, np.ndarray)}
            np.save(os.path.join(out_dir, f"{fname}_{m}_eigenvalues.npy"), s["eigenvalues"])
            np.save(os.path.join(out_dir, f"{fname}_{m}_P_Q.npy"), s["P_Q_values"])
            block, _ = layer_block_means(Q, store.keys)
            np.save(os.path.join(out_dir, f"{fname}_{m}_block_means.npy"), block)
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)

        # Update metadata last, so it only lists metrics whose outputs exist.
        meta.setdefault("metrics", []).extend(to_add)
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2, default=str)
def estimate_time_minutes(model_name, n_circuits=2, include_bias=True, fast_only=True):
    """Return a rough runtime estimate in minutes, or None for unknown models.

    The estimate is scaled from a gpt2-sized reference (144 heads, head
    dimension 64): an extraction term that grows with depth and head width,
    plus a pairwise term per circuit that grows with the number of head pairs
    and the squared head dimension. Models missing from ``MODEL_DIMS`` yield
    ``None``.
    """
    dims = MODEL_DIMS.get(model_name)
    if dims is None:
        return None

    layers, heads, head_dim = dims
    total_heads = layers * heads

    # Scale factors relative to the gpt2 reference configuration.
    pair_scale = (total_heads * (total_heads - 1) // 2) / (144 * 143 // 2)
    width_scale = (head_dim ** 2) / (64 ** 2)

    # Extraction: ~2 min at gpt2 scale.
    extraction = 2.0 * (layers / 12.0) * (head_dim / 64.0)

    # Pair loop: ~0.5 min per circuit at gpt2 scale for the fast metrics,
    # with a further ~17 min when the slower metrics are included.
    per_circuit = 0.5 * pair_scale * width_scale
    if not fast_only:
        per_circuit += 17.0 * pair_scale * width_scale

    estimate = extraction + per_circuit * n_circuits
    if include_bias:
        # Four bias vectors (b_Q, b_K, b_V, b_O), each charged one pair pass.
        estimate += per_circuit * 4
    return estimate
def main():
    """Run the correlation pipeline for the selected models and report results.

    Parses the command line and optionally overlays options from a JSON config
    file (keys ``metrics``, ``circuits``, ``cross``, ``models``). It then
    prints a plan with rough time estimates and, unless ``--dry-run`` is
    given, processes each model: the full multi-circuit analysis by default,
    or the incremental metric update when ``--add-metrics`` is given. A
    failure in one model is logged and does not stop the remaining models.
    Unless ``--no-plot`` is given, ``plot_corr_figures.py`` (located next to
    this script) is run on the output directory afterwards.

    Returns:
        Process exit code. 0 after a dry run, or when every model succeeded
        and the plotting step (if run) exited cleanly. 1 when there was
        nothing to run, any model failed, or the plotting script returned a
        non-zero exit status.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    parser = argparse.ArgumentParser(
        description="Full production run: correlations + plots for all models")
    parser.add_argument("--cache", type=str, required=True,
                        help="Model cache directory")
    parser.add_argument("--out", type=str, default="corr_out",
                        help="Output directory for correlation results")
    parser.add_argument("--skip", nargs="*", default=[],
                        help="Models to skip")
    parser.add_argument("--models", nargs="*", default=None,
                        help="Only run these models (default: all cached)")
    parser.add_argument("--device", type=str, default=None,
                        choices=["cuda", "mps", "cpu"])
    parser.add_argument("--fast", action="store_true",
                        help="Fast metrics only (skip histogram divergences)")
    parser.add_argument("--no-bias", action="store_true",
                        help="Skip bias correlations")
    parser.add_argument("--no-cross", action="store_true",
                        help="Skip cross-correlations")
    parser.add_argument("--no-plot", action="store_true",
                        help="Skip plot generation")
    parser.add_argument("--config", type=str, default=None,
                        help="Read detailed options from config")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print plan and time estimates without running")
    parser.add_argument("--add-metrics", nargs="+", default=None,
                        help="Add metrics to existing outputs without full re-extraction")
    args = parser.parse_args()

    metrics = FAST_METRICS if args.fast else DEFAULT_METRICS
    circuits = ("QK", "OV")
    cross = () if args.no_cross else ("QKOV", "WB")
    include_bias = not args.no_bias
    to_run = None

    # Values from the config file override the command-line derived defaults.
    if args.config is not None:
        with open(args.config, "r") as f:
            config_opts = json.load(f)
        if "metrics" in config_opts:
            metrics = config_opts["metrics"]
        if "circuits" in config_opts:
            circuits = tuple(config_opts["circuits"])
        if "cross" in config_opts:
            cross = tuple(config_opts["cross"])
        if "models" in config_opts:
            to_run = config_opts["models"]

    # A non-empty model list from the config is used as-is; otherwise the
    # cache is scanned and --models / --skip are applied.
    if not to_run:
        cached = find_cached_models(args.cache)
        if args.models:
            to_run = [m for m in args.models if m in cached]
            missing = [m for m in args.models if m not in cached]
            if missing:
                print(f"Warning: not cached: {missing}")
        else:
            to_run = [m for m in cached if m not in args.skip]
        if not to_run:
            print(f"No models to run. Cached: {cached}")
            return 1

    def fmt_minutes(minutes):
        # Minutes below one hour, hours (one decimal) from there on.
        if minutes < 60:
            return "{:.0f} min".format(minutes)
        return "{:.1f} hr".format(minutes / 60)

    row = " {:<26} {:>8} {:>8} {:>12}".format

    # Summary table
    print("\n" + "=" * 70)
    print("PRODUCTION RUN PLAN")
    print(f" Circuits: {', '.join(circuits)}")
    print(f" Biases: {'yes' if include_bias else 'no'}")
    print(f" Cross: {', '.join(cross) if cross else 'none'}")
    print(f" Metrics: {', '.join(metrics)}")
    print(f" Output: {os.path.abspath(args.out)}")
    print("-" * 70)
    print("{:<28} {:>8} {:>8} {:>12}".format("Model", "N_heads", "d_head", "Est. time"))
    print("-" * 70)
    total_est = 0.0
    for model in to_run:
        dims = MODEL_DIMS.get(model)
        if dims:
            n_l, n_h, d_h = dims
            est = estimate_time_minutes(model, len(circuits), include_bias, args.fast)
            total_est += est
            print(row(model, n_l * n_h, d_h, fmt_minutes(est)))
        else:
            # Models without known dimensions are run but not estimated.
            print(row(model, "?", "?", "unknown"))
    print("-" * 70)
    print(row("TOTAL", "", "", fmt_minutes(total_est)))
    print("=" * 70 + "\n")

    if args.dry_run:
        print("Dry run — exiting.")
        return 0

    # Run analysis
    os.makedirs(args.out, exist_ok=True)
    results = {}
    for idx, model in enumerate(to_run, start=1):
        print("\n" + "=" * 60)
        print(f"[{idx}/{len(to_run)}] {model}")
        print("=" * 60)
        t0 = time.time()
        try:
            if args.add_metrics:
                recompute_metrics_for_model(
                    model_name=model,
                    new_metrics=args.add_metrics,
                    out_dir=args.out,
                    cache_dir=args.cache,
                    device=args.device,
                )
            else:
                run_multi_circuit_analysis(
                    model_name=model,
                    circuits=circuits,
                    include_bias=include_bias,
                    cross_correlations=cross,
                    metrics=tuple(metrics),
                    cache_dir=args.cache,
                    out_dir=args.out,
                    device=args.device,
                )
        except Exception as e:
            # Batch boundary: record the failure (with traceback) and carry
            # on with the remaining models.
            elapsed = (time.time() - t0) / 60
            results[model] = ("FAILED", elapsed)
            logging.exception("FAILED: %s — %s", model, e)
        else:
            elapsed = (time.time() - t0) / 60
            results[model] = ("OK", elapsed)
            print(" Completed in {:.1f} min".format(elapsed))

    # Plot all
    plot_returncode = 0
    if not args.no_plot:
        print("\n" + "=" * 60)
        print("PLOTTING")
        print("=" * 60)
        plot_script = os.path.join(os.path.dirname(__file__), "plot_corr_figures.py")
        plot_returncode = subprocess.run(
            [sys.executable, plot_script, "--data", args.out]
        ).returncode
        if plot_returncode != 0:
            # Do not let a broken plotting step pass as a successful run.
            logging.error("Plotting failed with exit code %s", plot_returncode)

    # Final summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("-" * 60)
    for model, (status, elapsed) in results.items():
        print(" {:<30} {:>8} {:>8.1f} min".format(model, status, elapsed))
    n_ok = sum(1 for status, _ in results.values() if status == "OK")
    n_fail = len(results) - n_ok
    print("-" * 60)
    print(" {} succeeded, {} failed".format(n_ok, n_fail))
    if plot_returncode != 0:
        print(" Plotting failed (exit code {})".format(plot_returncode))
    print(" Output: {}".format(os.path.abspath(args.out)))
    print("=" * 60)
    return 0 if n_fail == 0 and plot_returncode == 0 else 1
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())