Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Plot correlations for all models found in a data directory. | |
| Automatically discovers models from metadata files and generates plots for each. | |
| Similar to run_all_correlations.py but for plotting instead of analysis. | |
| Usage: | |
| python scripts/plot_all_models.py --data corr_out | |
| python scripts/plot_all_models.py --data corr_out --skip gpt2 | |
| python scripts/plot_all_models.py --data corr_out --models gpt2 gpt2-large | |
| python scripts/plot_all_models.py --data corr_out --components weights # or biases | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| from collections import defaultdict | |
def find_models_and_components(data_dir):
    """Find all models and their components from metadata files.

    Metadata files are expected to be named
    ``{model}_{revision}_{component}_metadata.json``, for example::

        gpt2_main_W_QK_metadata.json
        gpt2_main_b_Q_metadata.json
        pythia-70m-deduped_step143000_W_OV_metadata.json

    The component is the first ``W``/``b`` part followed by another part
    (``W_QK``, ``b_Q``, ...). The part just before it is taken as the
    revision when it is ``main`` or ``step<N>``; otherwise the revision
    defaults to ``"main"`` and everything before the component is the model
    name. Cross-correlation files (containing ``_vs_``) and files whose name
    cannot be parsed into a non-empty model plus a component are skipped.

    Args:
        data_dir: Directory to scan (str or path-like). A missing directory
            yields an empty result.

    Returns:
        dict: {model_name: [(revision, component), ...]}
            where component is like 'W_QK', 'W_OV', 'b_Q', etc.
    """
    suffix = "_metadata.json"
    data_path = Path(data_dir)
    if not data_path.exists():
        return {}

    models = defaultdict(list)
    # sorted() makes the per-model list order independent of the filesystem.
    for metadata_file in sorted(data_path.glob(f"*{suffix}")):
        filename = metadata_file.name
        # Skip cross-correlation files (contain _vs_)
        if "_vs_" in filename:
            continue

        parts = filename[: -len(suffix)].split("_")

        # Index of the first "W"/"b" part that has a following name part;
        # iterating parts[:-1] guarantees parts[idx + 1] exists.
        component_idx = next(
            (i for i, part in enumerate(parts[:-1]) if part in ("W", "b")),
            None,
        )
        # idx 0 means there is no model name in front of the component;
        # the old code then read parts[-1] by accident and registered "".
        if component_idx is None or component_idx == 0:
            continue
        component = f"{parts[component_idx]}_{parts[component_idx + 1]}"

        # Everything before the component is model (+ optional revision).
        prefix = parts[:component_idx]
        if len(prefix) > 1 and re.fullmatch(r"main|step\d+", prefix[-1]):
            revision = prefix[-1]
            model = "_".join(prefix[:-1])
        else:
            revision = "main"
            model = "_".join(prefix)

        models[model].append((revision, component))

    return dict(models)
def categorize_component(component):
    """Classify a component name by its prefix.

    Names starting with ``W_`` are reported as ``"weight"``, names starting
    with ``b_`` as ``"bias"``, and anything else as ``"unknown"``.
    """
    for prefix, category in (("W_", "weight"), ("b_", "bias")):
        if component.startswith(prefix):
            return category
    return "unknown"
def plot_model_component(data_dir, model, revision, component, out_dir, quiet=False):
    """Run plot_correlations.py for a single model/revision/component.

    The plotting script is run in a subprocess with the current Python
    interpreter. It is resolved relative to this file (both scripts live in
    ``scripts/``), so the call works regardless of the working directory.

    Args:
        data_dir: Directory containing the correlation results.
        model: Model name as parsed from the metadata filenames.
        revision: Model revision (e.g. ``main``).
        component: Component name such as ``W_QK`` or ``b_Q``; passed via the
            plotting script's legacy ``--weight-type`` option.
        out_dir: Directory the figures are written to.
        quiet: If True, suppress progress output and capture the child's
            output. Failures are still reported, including captured stderr.

    Returns:
        bool: True if the plotting script exited with status 0, else False.
    """
    script = Path(__file__).resolve().parent / "plot_correlations.py"
    cmd = [
        sys.executable,
        str(script),
        "--data", data_dir,
        "--model", model,
        "--revision", revision,
        "--weight-type", component,  # legacy option name for the component
        "--out", out_dir,
    ]

    if not quiet:
        print(f" Plotting: {model} @ {revision} - {component}")

    try:
        # In quiet mode the child's output is captured so it only surfaces
        # (via e.stderr) when the run fails.
        subprocess.run(cmd, capture_output=quiet, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # Always report failures: --quiet only silences progress output, and
        # without the model context a quiet run would not say what failed.
        print(f" ERROR: {model} @ {revision} - {component}: {e}")
        if e.stderr:
            print(f" {e.stderr.strip()}")
        return False
    except OSError as e:
        # The interpreter/script could not be executed at all.
        print(f" ERROR: {model} @ {revision} - {component}: {e}")
        return False
    return True
def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Plot correlations for all models in data directory",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Plot all models
python scripts/plot_all_models.py --data corr_out
# Plot specific models only
python scripts/plot_all_models.py --data corr_out --models gpt2 gpt2-large
# Skip certain models
python scripts/plot_all_models.py --data corr_out --skip gpt2
# Plot only weights (no biases)
python scripts/plot_all_models.py --data corr_out --components weights
# Plot only biases
python scripts/plot_all_models.py --data corr_out --components biases
# Quiet mode (less output)
python scripts/plot_all_models.py --data corr_out --quiet
"""
    )
    parser.add_argument(
        "--data", type=str, default="corr_out",
        help="Data directory containing correlation results (default: corr_out)"
    )
    parser.add_argument(
        "--out", type=str, default=None,
        help="Output directory for figures (default: {data}/figures)"
    )
    parser.add_argument(
        "--models", nargs="*", default=None,
        help="Specific models to plot (default: all found models)"
    )
    parser.add_argument(
        "--skip", nargs="*", default=[],
        help="Models to skip"
    )
    parser.add_argument(
        "--components", choices=["weights", "biases", "all"], default="all",
        help="Which components to plot (default: all)"
    )
    parser.add_argument(
        "--quiet", "-q", action="store_true",
        help="Suppress detailed output"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Show what would be plotted without plotting"
    )
    parser.add_argument(
        "--build-dataset", type=str, default=None, metavar="REPO",
        help="After plotting, build and push HF dataset "
             "(e.g. user/transformer-analysis-figures)"
    )
    return parser


def _select_models(models_components, models=None, skip=(), components="all"):
    """Apply the --models, --skip and --components filters.

    Args:
        models_components: {model: [(revision, component), ...]} as returned
            by find_models_and_components.
        models: Model names to keep; falsy means keep all.
        skip: Model names to drop.
        components: "weights", "biases" or "all".

    Returns:
        dict: {model: sorted list of unique (revision, component)}, containing
        only models that still have at least one selected component.
    """
    if models:
        missing = sorted(set(models) - set(models_components))
        if missing:
            print(f"Warning: requested model(s) not found: {', '.join(missing)}")

    # None (for "all") disables the component-type filter.
    wanted_category = {"weights": "weight", "biases": "bias"}.get(components)
    skip = skip or ()

    selected = {}
    for model, entries in models_components.items():
        if (models and model not in models) or model in skip:
            continue
        kept = sorted({
            (rev, comp) for rev, comp in entries
            if wanted_category is None
            or categorize_component(comp) == wanted_category
        })
        if kept:
            selected[model] = kept
    return selected


def _print_selection_summary(models_to_plot):
    """Print a table of the selected models and their component names."""
    total = sum(len(entries) for entries in models_to_plot.values())
    print("\n" + "=" * 70)
    print(f"Found {len(models_to_plot)} models with {total} components:")
    print("-" * 70)
    for model, entries in sorted(models_to_plot.items()):
        names = ", ".join(sorted({comp for _, comp in entries}))
        print(f" {model:<30} {len(entries):>2} components: {names}")
    print("=" * 70)


def _plot_comparisons(data_dir, models_to_plot, out_dir):
    """Draw cross-model comparison plots, one set per component.

    Each comparison only receives the models that actually have that
    component; components present in fewer than two models are skipped.
    Errors are printed and do not abort the remaining plots.
    """
    print("\nGenerating multi-model comparison plots...")
    try:
        # Sibling module: importable because the script's own directory is
        # on sys.path when run as `python scripts/plot_all_models.py`.
        from plot_correlations import (plot_eigenvalue_comparison,
                                       plot_eigen_stats_comparison)
    except ImportError as e:
        print(f" *** Could not import comparison plotter: {e}")
        return

    models_by_component = defaultdict(list)
    for model, entries in sorted(models_to_plot.items()):
        for comp in sorted({comp for _, comp in entries}):
            models_by_component[comp].append(model)

    plotters = (
        (plot_eigenvalue_comparison, "eigenvalues"),
        (plot_eigen_stats_comparison, "eigen stats"),
    )
    for wt, model_list in sorted(models_by_component.items()):
        if len(model_list) < 2:
            continue  # nothing to compare against
        for plot_fn, label in plotters:
            try:
                plot_fn(data_dir, model_list, weight_type=wt, out_dir=out_dir)
            except Exception as e:  # one bad component must not stop the rest
                print(f" *** Error on {wt} {label}: {e}")


def _push_hf_dataset(out_dir, repo):
    """Build a dataset from out_dir via build_hf_dataset and push it to repo.

    Any failure (including a missing build_hf_dataset module) is printed and
    swallowed so an upload problem never turns a finished plot run into a
    crash.
    """
    print(f"\nBuilding HF dataset → {repo}")
    try:
        from build_hf_dataset import build_dataset
        ds = build_dataset(out_dir)
        ds.push_to_hub(repo)
    except Exception as e:
        print(f" *** Dataset build failed: {e}")
        return
    print(f"Pushed: https://huggingface.co/datasets/{repo}")


def main():
    """Discover, filter and plot every model/component in the data directory.

    Returns:
        int: Exit code. 0 when every plot succeeded (or on --dry-run); 1 when
        no models were found, nothing remained after filtering, or any plot
        failed.
    """
    args = _build_parser().parse_args()
    # Default output directory
    out_dir = args.out or os.path.join(args.data, "figures")

    if not args.quiet:
        print(f"Scanning directory: {args.data}")
    models_components = find_models_and_components(args.data)
    if not models_components:
        print(f"No models found in {args.data}")
        print("Make sure the directory contains *_metadata.json files")
        return 1

    # Check emptiness after *all* filters, including --components; previously
    # an empty component selection fell through and reported success.
    models_to_plot = _select_models(
        models_components, args.models, args.skip, args.components
    )
    if not models_to_plot:
        print("No models to plot after filtering")
        return 1

    _print_selection_summary(models_to_plot)
    if args.dry_run:
        print("\nDry run - exiting without plotting")
        return 0

    os.makedirs(out_dir, exist_ok=True)
    print(f"\nOutput directory: {out_dir}\n")

    success_count = 0
    fail_count = 0
    total_models = len(models_to_plot)
    for i, (model, entries) in enumerate(sorted(models_to_plot.items()), 1):
        print(f"[{i}/{total_models}] {model}")
        for revision, component in entries:
            if plot_model_component(
                args.data, model, revision, component, out_dir, args.quiet
            ):
                success_count += 1
            else:
                fail_count += 1

    print("\n" + "=" * 70)
    print("Plotting complete!")
    print(f" Success: {success_count}")
    print(f" Failed: {fail_count}")
    print(f" Output: {out_dir}")
    print("=" * 70)

    # A multi-model comparison needs more than one model to compare.
    if success_count > 1 and len(models_to_plot) > 1:
        _plot_comparisons(args.data, models_to_plot, out_dir)

    if args.build_dataset and success_count > 0:
        _push_hf_dataset(out_dir, args.build_dataset)

    # Reminder to regenerate viewer
    if success_count > 0:
        print("\nTo view in browser, regenerate the viewer index:")
        print(f" python scripts/generate_viewer_index.py --out {args.data} --serve")
        print("=" * 70)

    return 0 if fail_count == 0 else 1
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())