#!/usr/bin/env python3
"""
Weight Analysis Script - Flexible harness for transformer weight analysis

This script supports three modes of operation:
1. Single model, single revision (e.g., gpt2 at main)
2. Single model, all revisions (e.g., pythia-70m-deduped across all checkpoints)
3. Config file mode (process multiple models/revisions from JSON config)

Examples:
    # Single model, single revision
    python run_weight_analysis.py --model gpt2 --revision main

    # Single model, all revisions
    python run_weight_analysis.py --model pythia-70m-deduped --all-revisions

    # Process all revisions and merge into single dataset
    python run_weight_analysis.py --model pythia-70m-deduped --all-revisions --merge

    # Config file mode
    python run_weight_analysis.py --config sweep_config.json

    # List available models
    python run_weight_analysis.py --list-models
"""

import os
import sys
import json
import argparse
import re
import warnings
from pathlib import Path
from typing import List, Optional, Dict, Any

from tqdm import tqdm
from transformers import logging as hf_logging

from transformer_analysis.weight_analysis import process_model, create_campaign, merge_versions
from transformer_analysis.model_registry import MODEL_CONFIGS, list_supported_models, get_model_config

# Matches a parameter-count token such as "70m", "1.4b" or "12b" in a model name.
_MODEL_SIZE_PATTERN = re.compile(r"(\d+\.?\d*)([mb])")


def process_single_model(
    model_name: str,
    revision: Optional[str] = None,
    all_revisions: bool = False,
    out_dir: str = "outputs",
    cache_dir: str = "./model_data",
    clobber: bool = False,
    cleanup_downloads: bool = False,
    low_rank_svd: bool = False,
    top_k_svd: int = -1,
    resume_download: bool = True,
    max_workers: int = 4,
    device: Optional[str] = None,
    quiet: bool = False,
    merge: bool = False,
    merge_suffix: str = "all_checkpoints",
    skip_postprocess: bool = False,
    binning_strategy: str = "fixed",
) -> None:
    """Run the weight analysis for one model over one or more revisions.

    The revisions to process are chosen as follows: every revision registered
    for the model when ``all_revisions`` is set (falling back to the main
    branch when none are registered), otherwise ``revision`` if given,
    otherwise the main branch (represented as ``None``).

    A revision is skipped when its output directory already exists and
    ``clobber`` is false. A failure while processing one revision is reported
    and does not stop the remaining revisions.

    Args:
        model_name: Name looked up via ``get_model_config``.
        revision: Specific revision to process; ignored with ``all_revisions``.
        all_revisions: Process every revision listed in the model config.
        out_dir: Directory that receives one sub-directory per revision.
        cache_dir: Download cache directory forwarded to ``process_model``.
        clobber: Re-process even if the output directory already exists.
        cleanup_downloads: Forwarded to ``process_model``.
        low_rank_svd: Forwarded as ``low_rank_svd_approximation``.
        top_k_svd: Forwarded to ``process_model``.
        resume_download: Forwarded to ``process_model``.
        max_workers: Forwarded to ``process_model``.
        device: Forwarded to ``process_model``.
        quiet: Suppress progress output (errors are still printed).
        merge: Call ``merge_versions`` afterwards when more than one revision
            was selected.
        merge_suffix: Suffix forwarded to ``merge_versions``.
        skip_postprocess: Forwarded to ``process_model``.
        binning_strategy: Forwarded to ``process_model``.
    """
    if not quiet:
        print("\n" + "=" * 80)
        print(f"Processing Model: {model_name}")
        print("=" * 80)

    try:
        model_config = get_model_config(model_name)
    except ValueError as e:
        print(f"ERROR: {e}")
        return

    # Determine which revisions to process
    if all_revisions:
        revisions = model_config.revisions
        if not revisions:
            print(f"Model {model_name} has no revisions defined. Processing main branch only.")
            revisions = [None]
    elif revision:
        revisions = [revision]
    else:
        revisions = [None]  # Main/latest branch

    if not quiet:
        print(f"Revisions to process: {len(revisions)}")
        if len(revisions) <= 10:
            print(f" {revisions}")

    # Process each revision
    for rev in tqdm(revisions, desc=f"Processing {model_name}", disable=quiet):
        revision_str = rev if rev else "main"

        # Check if output already exists
        if rev:
            target_dir = os.path.join(out_dir, f"{model_name}_{revision_str}")
        else:
            target_dir = os.path.join(out_dir, model_name)

        if os.path.exists(target_dir) and not clobber:
            if not quiet:
                print(f" Skipping {model_name} @ {revision_str} (output exists)")
            continue

        if not quiet:
            print(f"\n Processing: {model_name} @ {revision_str}")

        try:
            process_model(
                model_name=model_name,
                revision=rev,
                out_dir=out_dir,
                cache_dir=cache_dir,
                cleanup_downloads=cleanup_downloads,
                low_rank_svd_approximation=low_rank_svd,
                top_k_svd=top_k_svd,
                resume_download=resume_download,
                max_workers=max_workers,
                device=device,
                skip_postprocess=skip_postprocess,
                binning_strategy=binning_strategy,
            )
        except Exception as e:
            # Deliberate best-effort: one bad checkpoint must not abort a sweep.
            print(f" ERROR processing {model_name} @ {revision_str}: {e}")
            continue

    if not quiet:
        print("\n" + "=" * 80)
        print(f"Completed: {model_name}")
        print("=" * 80 + "\n")

    # Merge datasets if requested and multiple revisions were processed
    if merge and len(revisions) > 1:
        if not quiet:
            print("\n" + "=" * 80)
            print(f"Merging datasets for: {model_name}")
            print("=" * 80)

        try:
            merge_versions(
                model_name=model_name,
                path=out_dir,
                suffix=merge_suffix,
            )
            if not quiet:
                print(f"Successfully merged to: {out_dir}/{model_name}_{merge_suffix}")
                print("=" * 80 + "\n")
        except Exception as e:
            print(f"ERROR merging datasets: {e}")
            if not quiet:
                print("=" * 80 + "\n")


def process_from_config(
    config_path: str,
    out_dir: Optional[str] = None,
    cache_dir: Optional[str] = None,
    clobber: Optional[bool] = None,
    cleanup_downloads: Optional[bool] = None,
    low_rank_svd: Optional[bool] = None,
    top_k_svd: Optional[int] = None,
    resume_download: Optional[bool] = None,
    max_workers: Optional[int] = None,
    device: Optional[str] = None,
    quiet: bool = False,
    skip_postprocess: bool = False,
    binning_strategy: str = "fixed",
) -> None:
    """Process every model listed in a JSON config file.

    Global settings are read from the top-level config keys ``output_dir``,
    ``cache_dir``, ``clobber``, ``cleanup_downloads``, ``low_rank_svd``,
    ``top_k_svd``, ``resume_download``, ``max_workers`` and ``device``. For
    each of them the matching argument of this function overrides the config
    value unless it is ``None`` (for ``out_dir``/``cache_dir``: unless falsy).

    The ``models`` key holds a list whose entries are either a model name
    string or a dict with a required ``name`` and optional ``revisions``
    (list), ``all_revisions``, ``low_rank_svd`` and ``top_k_svd`` keys.
    Invalid entries are reported and skipped.

    Args:
        config_path: Path to the JSON config file.
        out_dir: Override for the config's ``output_dir``.
        cache_dir: Override for the config's ``cache_dir``.
        clobber: Override for the config's ``clobber``.
        cleanup_downloads: Override for the config's ``cleanup_downloads``.
        low_rank_svd: Override for the config's global ``low_rank_svd``.
        top_k_svd: Override for the config's global ``top_k_svd``.
        resume_download: Override for the config's ``resume_download``.
        max_workers: Override for the config's ``max_workers``.
        device: Override for the config's ``device``.
        quiet: Suppress progress output (errors are still printed).
        skip_postprocess: Forwarded to ``process_single_model`` for every model.
        binning_strategy: Forwarded to ``process_single_model`` for every model.
    """
    if not os.path.exists(config_path):
        print(f"ERROR: Config file not found: {config_path}")
        return

    try:
        # JSON interchange text is UTF-8; do not depend on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in config file: {e}")
        return

    def _resolve(cli_value: Any, key: str, default: Any) -> Any:
        """Return the CLI override if given, else the config value, else default."""
        return cli_value if cli_value is not None else config.get(key, default)

    # Extract global settings (with CLI overrides)
    global_out_dir = out_dir or config.get("output_dir", "outputs")
    global_cache_dir = cache_dir or config.get("cache_dir", "./model_data")
    global_clobber = _resolve(clobber, "clobber", False)
    global_cleanup = _resolve(cleanup_downloads, "cleanup_downloads", False)
    global_low_rank_svd = _resolve(low_rank_svd, "low_rank_svd", False)
    global_top_k_svd = _resolve(top_k_svd, "top_k_svd", -1)
    global_resume_download = _resolve(resume_download, "resume_download", True)
    global_max_workers = _resolve(max_workers, "max_workers", 4)
    global_device = _resolve(device, "device", None)

    if not quiet:
        print("\n" + "=" * 80)
        print(f"Processing from config: {config_path}")
        print("=" * 80)
        print(f"Output directory: {global_out_dir}")
        print(f"Cache directory: {global_cache_dir}")
        print(f"Clobber: {global_clobber}")
        print(f"Cleanup downloads: {global_cleanup}")
        print(f"Low-rank SVD: {global_low_rank_svd}")
        if global_low_rank_svd:
            print(f"Top-k SVD: {global_top_k_svd}")
        print(f"Resume downloads: {global_resume_download}")
        print(f"Max workers: {global_max_workers}")
        print("=" * 80 + "\n")

    # Process models
    models = config.get("models", [])
    if not models:
        print("ERROR: No models specified in config file")
        return

    for model_spec in models:
        if isinstance(model_spec, str):
            # Simple format: just model name
            model_name = model_spec
            revisions = None
            all_revs = False
            model_low_rank = global_low_rank_svd
            model_top_k = global_top_k_svd
        elif isinstance(model_spec, dict):
            # Detailed format
            model_name = model_spec.get("name")
            if not model_name:
                print("ERROR: Model specification missing 'name' field")
                continue
            revisions = model_spec.get("revisions")
            all_revs = model_spec.get("all_revisions", False)
            model_low_rank = model_spec.get("low_rank_svd", global_low_rank_svd)
            model_top_k = model_spec.get("top_k_svd", global_top_k_svd)
        else:
            print(f"ERROR: Invalid model specification: {model_spec}")
            continue

        # Settings shared by every process_single_model call for this model.
        common_kwargs = dict(
            model_name=model_name,
            out_dir=global_out_dir,
            cache_dir=global_cache_dir,
            clobber=global_clobber,
            cleanup_downloads=global_cleanup,
            low_rank_svd=model_low_rank,
            top_k_svd=model_top_k,
            resume_download=global_resume_download,
            max_workers=global_max_workers,
            device=global_device,
            quiet=quiet,
            skip_postprocess=skip_postprocess,
            binning_strategy=binning_strategy,
        )

        # Process based on specification
        if revisions:
            # Specific revisions listed
            for rev in revisions:
                process_single_model(revision=rev, all_revisions=False, **common_kwargs)
        else:
            # Process with all_revisions flag
            process_single_model(revision=None, all_revisions=all_revs, **common_kwargs)


def model_size_from_name(model_name: str) -> float:
    """Extract model size for sorting (in millions of parameters).

    The first ``<number>m`` or ``<number>b`` token found in the lower-cased
    name is used; ``b`` values are multiplied by 1000. Returns ``0.0`` when
    the name contains no such token.
    """
    # Extract size like "70m", "1.4b", "12b"
    match = _MODEL_SIZE_PATTERN.search(model_name.lower())
    if not match:
        return 0.0

    size_text, unit = match.groups()
    size = float(size_text)

    # Convert to millions for consistent comparison
    if unit == "b":
        size *= 1000
    return size


def list_models() -> None:
    """Print all supported models grouped by family and sorted by size."""
    models = list_supported_models()

    print("\nAvailable models:")
    print("=" * 80)

    # Group by model family
    families: Dict[str, List[str]] = {}
    for model in models:
        if "pythia" in model:
            family = "Pythia"
        elif "gpt2" in model:
            family = "GPT-2"
        elif "llama" in model:
            family = "LLaMA"
        elif "mistral" in model or "mixtral" in model:
            family = "Mistral/Mixtral"
        else:
            family = "Other"
        families.setdefault(family, []).append(model)

    for family, family_models in sorted(families.items()):
        print(f"\n{family}:")
        # Sort models within family by size
        for model in sorted(family_models, key=model_size_from_name):
            model_config = get_model_config(model)
            n_revisions = len(model_config.revisions)
            rev_info = f"({n_revisions} revisions)" if n_revisions > 0 else "(no revisions)"
            print(f" - {model} {rev_info}")

    print("\n" + "=" * 80)
    print(f"Total models: {len(models)}")
    print("=" * 80 + "\n")


def main() -> None:
    """Parse command-line arguments and dispatch to the selected mode.

    Exactly one of ``--model``, ``--config`` or ``--list-models`` is required.
    In config mode an option is forwarded as an override only when it differs
    from its command-line default; otherwise ``None`` is passed so the value
    from the config file is used.
    """
    parser = argparse.ArgumentParser(
        description="Weight Analysis Script - Flexible harness for transformer weight analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single model, main revision
  %(prog)s --model gpt2

  # Single model, specific revision
  %(prog)s --model pythia-70m-deduped --revision step1000

  # Single model, all revisions
  %(prog)s --model pythia-70m-deduped --all-revisions

  # Process all revisions and merge into single dataset
  %(prog)s --model pythia-70m-deduped --all-revisions --merge

  # Config file mode
  %(prog)s --config sweep_config.json

  # With low-rank SVD
  %(prog)s --model gpt2 --low-rank-svd --top-k-svd 64

  # List available models
  %(prog)s --list-models
""",
    )

    # Mode selection
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--model", type=str, help="Model name to process")
    mode_group.add_argument("--config", type=str, help="Path to JSON config file")
    mode_group.add_argument("--list-models", action="store_true", help="List all available models")

    # Model-specific options
    parser.add_argument("--revision", type=str, help="Specific revision/checkpoint to process")
    parser.add_argument("--all-revisions", action="store_true",
                        help="Process all available revisions for the model")

    # Directory options
    parser.add_argument("--out", type=str, default="outputs",
                        help="Output directory (default: outputs)")
    parser.add_argument("--cache", type=str, default="./model_data",
                        help="Cache directory for downloads (default: ./model_data)")

    # Processing options
    parser.add_argument("--clobber", action="store_true", help="Overwrite existing outputs")
    parser.add_argument("--cleanup", action="store_true",
                        help="Clean up downloaded models after processing")
    parser.add_argument("--merge", action="store_true",
                        help="Merge all revisions into a single dataset after processing (only with --all-revisions)")
    parser.add_argument("--merge-suffix", type=str, default="all_checkpoints",
                        help="Suffix for merged dataset (default: all_checkpoints)")
    parser.add_argument("--low-rank-svd", action="store_true")
    parser.add_argument("--top-k-svd", type=int, default=-1)
    parser.add_argument("--binning-strategy", type=str, default="fixed",
                        choices=["fixed", "scott", "fd"],
                        help="Histogram binning strategy: fixed linspace (default), Scott's rule, or Freedman-Diaconis")
    parser.add_argument("--resume-download", action="store_true", default=True, dest="resume_download")
    parser.add_argument("--no-resume-download", action="store_false", dest="resume_download")
    parser.add_argument("--max-workers", type=int, default=4, dest="max_workers")
    parser.add_argument("--device", type=str, default=None, choices=["cuda", "mps", "cpu"])
    parser.add_argument("--quiet", "-q", action="store_true")

    # Post-processing options
    parser.add_argument("--no-postprocess", action="store_true",
                        help="Skip post-processing metrics computation")
    parser.add_argument("--reprocess-metrics", action="store_true",
                        help="Reprocess existing datasets to update/add metrics without re-running full analysis")

    args = parser.parse_args()

    # Suppress warnings unless verbose
    if args.quiet:
        hf_logging.set_verbosity_error()
        warnings.filterwarnings("ignore")

    # Handle list-models mode
    if args.list_models:
        list_models()
        return

    # Handle config file mode
    if args.config:
        process_from_config(
            config_path=args.config,
            # A value equal to the CLI default is treated as "not given" so
            # that the config file's setting applies.
            out_dir=args.out if args.out != "outputs" else None,
            cache_dir=args.cache if args.cache != "./model_data" else None,
            clobber=args.clobber if args.clobber else None,
            cleanup_downloads=args.cleanup if args.cleanup else None,
            low_rank_svd=args.low_rank_svd if args.low_rank_svd else None,
            top_k_svd=args.top_k_svd if args.top_k_svd != -1 else None,
            # Previously these two were always forwarded, which made the
            # config file's "resume_download" and "max_workers" unreachable.
            resume_download=None if args.resume_download else False,
            max_workers=args.max_workers if args.max_workers != 4 else None,
            device=args.device,
            quiet=args.quiet,
            skip_postprocess=args.no_postprocess,
            binning_strategy=args.binning_strategy,
        )
        return

    # Handle single model mode
    if args.model:
        # Check for reprocess-metrics mode
        if args.reprocess_metrics:
            # Import the reprocessing function
            from transformer_analysis.weight_analysis import reprocess_metrics

            reprocess_metrics(
                model_name=args.model,
                revision=args.revision,
                all_revisions=args.all_revisions,
                out_dir=args.out,
                quiet=args.quiet,
            )
            return

        # Suppress warnings for cleaner output
        hf_logging.set_verbosity_error()
        warnings.filterwarnings("ignore")

        process_single_model(
            model_name=args.model,
            revision=args.revision,
            all_revisions=args.all_revisions,
            out_dir=args.out,
            cache_dir=args.cache,
            clobber=args.clobber,
            cleanup_downloads=args.cleanup,
            low_rank_svd=args.low_rank_svd,
            top_k_svd=args.top_k_svd,
            resume_download=args.resume_download,
            max_workers=args.max_workers,
            device=args.device,
            quiet=args.quiet,
            merge=args.merge,
            merge_suffix=args.merge_suffix,
            skip_postprocess=args.no_postprocess,
            binning_strategy=args.binning_strategy,
        )


if __name__ == "__main__":
    main()