File size: 11,016 Bytes
eb2f1cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
#!/usr/bin/env python3
"""Plot correlations for all models found in a data directory.

Automatically discovers models from metadata files and generates plots for each.
Similar to run_all_correlations.py but for plotting instead of analysis.

Usage:
    python scripts/plot_all_models.py --data corr_out
    python scripts/plot_all_models.py --data corr_out --skip gpt2
    python scripts/plot_all_models.py --data corr_out --models gpt2 gpt2-large
    python scripts/plot_all_models.py --data corr_out --components weights  # or biases
"""

import argparse
import json
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from collections import defaultdict


def find_models_and_components(data_dir):
    """Find all models and their components from metadata files.

    Scans ``data_dir`` (non-recursively) for files named
    ``{model}_{revision}_{component}_metadata.json``, for example::

        gpt2_main_W_QK_metadata.json
        gpt2_main_b_Q_metadata.json
        pythia-70m-deduped_step143000_W_OV_metadata.json

    Cross-correlation files (names containing ``_vs_``) are ignored.

    The component is the first ``W``/``b`` token that is followed by another
    token (giving e.g. ``W_QK`` or ``b_Q``). The token just before it is taken
    as the revision when it is ``main`` or ``step<N>``; otherwise the revision
    defaults to ``"main"`` and everything before the component is the model
    name. Files with no recognizable component, or with no model name left
    in front of it, are skipped.

    Args:
        data_dir: Directory (str or path-like) to scan.

    Returns:
        dict: {model_name: [(revision, component), ...]}
              where component is like 'W_QK', 'W_OV', 'b_Q', etc.
              Empty if ``data_dir`` does not exist. Lists follow the
              filesystem glob order (unsorted).
    """
    data_path = Path(data_dir)
    if not data_path.exists():
        return {}

    # A revision token is either the default branch or any training
    # checkpoint ("step0", "step1000", "step143000", ...).
    revision_re = re.compile(r"main|step\d+")
    suffix = "_metadata.json"

    models = defaultdict(list)

    for metadata_file in data_path.glob("*" + suffix):
        filename = metadata_file.name

        # Skip cross-correlation files (contain _vs_)
        if "_vs_" in filename:
            continue

        parts = filename[:-len(suffix)].split("_")

        # First "W"/"b" token that still has a following token to pair with.
        component_idx = next(
            (i for i, part in enumerate(parts[:-1]) if part in ("W", "b")),
            None,
        )
        if component_idx is None:
            continue

        component = f"{parts[component_idx]}_{parts[component_idx + 1]}"

        # Everything before the component is model (+ optional revision).
        prefix = parts[:component_idx]
        if prefix and revision_re.fullmatch(prefix[-1]):
            revision = prefix[-1]
            model = "_".join(prefix[:-1])
        else:
            revision = "main"
            model = "_".join(prefix)

        if not model:
            # e.g. "W_QK_metadata.json" or "main_W_QK_metadata.json":
            # nothing identifies a model, so don't register an empty name.
            continue

        models[model].append((revision, component))

    return dict(models)


def categorize_component(component):
    """Classify a component name by its prefix.

    Names beginning with ``W_`` are weight matrices and names beginning with
    ``b_`` are bias vectors; anything else is reported as ``"unknown"``.
    """
    for prefix, category in (("W_", "weight"), ("b_", "bias")):
        if component.startswith(prefix):
            return category
    return "unknown"


def plot_model_component(data_dir, model, revision, component, out_dir, quiet=False):
    """Run plot_correlations.py for a specific model/component.

    Launches the plotting script in a child process using the current
    interpreter. In quiet mode the child's stdout/stderr are captured instead
    of shown; if the child fails, the failure (and its captured stderr) is
    still printed so it can be diagnosed.

    Args:
        data_dir: Directory containing correlation results.
        model: Model name as discovered from metadata filenames.
        revision: Model revision (e.g. "main").
        component: Component name such as "W_QK" or "b_Q"; passed to the
            plotting script as ``--weight-type`` (its legacy parameter name).
        out_dir: Directory where figures are written.
        quiet: Suppress progress output and capture the child's output.

    Returns:
        bool: True if the plotting process exited successfully, else False.
    """
    # Prefer the plotting script next to this file so the command works
    # from any working directory; fall back to the repo-root-relative path.
    script = Path(__file__).resolve().parent / "plot_correlations.py"
    if not script.is_file():
        script = Path("scripts/plot_correlations.py")

    cmd = [
        sys.executable,
        str(script),
        "--data", str(data_dir),
        "--model", model,
        "--revision", revision,
        "--weight-type", component,
        "--out", str(out_dir),
    ]

    if not quiet:
        print(f"  Plotting: {model} @ {revision} - {component}")

    try:
        subprocess.run(cmd, capture_output=quiet, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # Report failures even in quiet mode: there the child's stderr was
        # captured and this is the only place it can be shown. In verbose
        # mode stderr was not captured (e.stderr is None) because the child
        # already wrote it to the terminal.
        print(f"    ERROR: {model} @ {revision} - {component}: {e}")
        if e.stderr:
            print(f"    {e.stderr.strip()}")
        return False
    except Exception as e:
        # Best effort: one component that cannot be launched must not abort
        # the whole batch; the caller counts it as a failure.
        print(f"    ERROR: {model} @ {revision} - {component}: {e}")
        return False
    return True


def _build_arg_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Plot correlations for all models in data directory",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Plot all models
  python scripts/plot_all_models.py --data corr_out

  # Plot specific models only
  python scripts/plot_all_models.py --data corr_out --models gpt2 gpt2-large

  # Skip certain models
  python scripts/plot_all_models.py --data corr_out --skip gpt2

  # Plot only weights (no biases)
  python scripts/plot_all_models.py --data corr_out --components weights

  # Plot only biases
  python scripts/plot_all_models.py --data corr_out --components biases

  # Quiet mode (less output)
  python scripts/plot_all_models.py --data corr_out --quiet
        """
    )

    parser.add_argument(
        "--data", type=str, default="corr_out",
        help="Data directory containing correlation results (default: corr_out)"
    )
    parser.add_argument(
        "--out", type=str, default=None,
        help="Output directory for figures (default: {data}/figures)"
    )
    parser.add_argument(
        "--models", nargs="*", default=None,
        help="Specific models to plot (default: all found models)"
    )
    parser.add_argument(
        "--skip", nargs="*", default=[],
        help="Models to skip"
    )
    parser.add_argument(
        "--components", choices=["weights", "biases", "all"], default="all",
        help="Which components to plot (default: all)"
    )
    parser.add_argument(
        "--quiet", "-q", action="store_true",
        help="Suppress detailed output"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Show what would be plotted without plotting"
    )
    parser.add_argument(
        "--build-dataset", type=str, default=None, metavar="REPO",
        help="After plotting, build and push HF dataset "
             "(e.g. user/transformer-analysis-figures)"
    )
    return parser


def _select_models(models_components, include=None, skip=(), components="all"):
    """Apply the --models, --skip and --components filters.

    Args:
        models_components: {model: [(revision, component), ...]} as returned
            by find_models_and_components().
        include: Model names to keep; None or empty keeps every model.
        skip: Model names to drop.
        components: "weights", "biases" or "all".

    Returns:
        A new {model: [(revision, component), ...]} dict. Models left with no
        components after filtering are dropped.
    """
    skip = skip or ()
    selected = {
        model: comps
        for model, comps in models_components.items()
        if (not include or model in include) and model not in skip
    }
    if components == "all":
        return selected

    # CLI choices are plural; categorize_component() returns the singular.
    wanted = {"weights": "weight", "biases": "bias"}.get(components, components)
    filtered = {}
    for model, comps in selected.items():
        kept = [
            (rev, comp) for rev, comp in comps
            if categorize_component(comp) == wanted
        ]
        if kept:
            filtered[model] = kept
    return filtered


def _print_summary(models_to_plot):
    """Print a table of the models and component names that will be plotted."""
    n_components = sum(len(components) for components in models_to_plot.values())

    print("\n" + "=" * 70)
    print(f"Found {len(models_to_plot)} models with {n_components} components:")
    print("-" * 70)

    for model, components in sorted(models_to_plot.items()):
        # Component names only; revisions are collapsed in this listing.
        names = sorted({comp for _, comp in components})
        print(f"  {model:<30} {len(components):>2} components: {', '.join(names)}")

    print("=" * 70)


def _plot_comparisons(data_dir, models_to_plot, out_dir):
    """Generate cross-model comparison plots, one set per component name.

    Each component name is plotted only for the models that actually have it,
    and only when at least two models do (a comparison needs two sides).
    Errors are printed and do not stop the remaining plots.
    """
    print("\nGenerating multi-model comparison plots...")
    try:
        from plot_correlations import (plot_eigenvalue_comparison,
                                        plot_eigen_stats_comparison)
    except ImportError as e:
        print(f"  *** Could not import comparison plotter: {e}")
        return

    models_by_wt = defaultdict(set)
    for model, components in models_to_plot.items():
        for _, comp in components:
            models_by_wt[comp].add(model)

    plotters = (
        ("eigenvalues", plot_eigenvalue_comparison),
        ("eigen stats", plot_eigen_stats_comparison),
    )
    for wt in sorted(models_by_wt):
        model_list = sorted(models_by_wt[wt])
        if len(model_list) < 2:
            continue
        for label, plot_fn in plotters:
            try:
                plot_fn(data_dir, model_list, weight_type=wt, out_dir=out_dir)
            except Exception as e:
                print(f"  *** Error on {wt} {label}: {e}")


def main():
    """Discover models, plot every selected component, then post-process.

    Post-processing covers multi-model comparison plots, an optional HF
    dataset push, and a reminder about the viewer index.

    Returns:
        int: Exit code. 0 when every plot succeeded or on --dry-run. 1 when
        no models were found or selected, or when any plot failed.
    """
    args = _build_arg_parser().parse_args()

    # Default output directory
    out_dir = args.out or os.path.join(args.data, "figures")

    if not args.quiet:
        print(f"Scanning directory: {args.data}")

    models_components = find_models_and_components(args.data)

    if not models_components:
        print(f"No models found in {args.data}")
        print("Make sure the directory contains *_metadata.json files")
        return 1

    if args.models:
        missing = sorted(set(args.models) - set(models_components))
        if missing:
            print(f"Warning: requested models not found: {', '.join(missing)}")

    models_to_plot = _select_models(
        models_components,
        include=args.models,
        skip=args.skip,
        components=args.components,
    )

    # Checked after *all* filters, so that e.g. --components biases on a
    # weights-only directory is reported instead of "succeeding" with 0 plots.
    if not models_to_plot:
        print("No models to plot after filtering")
        return 1

    _print_summary(models_to_plot)

    if args.dry_run:
        print("\nDry run - exiting without plotting")
        return 0

    os.makedirs(out_dir, exist_ok=True)
    print(f"\nOutput directory: {out_dir}\n")

    success_count = 0
    fail_count = 0

    for i, (model, components) in enumerate(sorted(models_to_plot.items()), 1):
        print(f"[{i}/{len(models_to_plot)}] {model}")

        # set() drops any duplicate (revision, component) pairs.
        for revision, component in sorted(set(components)):
            if plot_model_component(
                args.data, model, revision, component, out_dir, args.quiet
            ):
                success_count += 1
            else:
                fail_count += 1

    # Summary
    print("\n" + "=" * 70)
    print("Plotting complete!")
    print(f"  Success: {success_count}")
    print(f"  Failed:  {fail_count}")
    print(f"  Output:  {out_dir}")
    print("=" * 70)

    # Comparison plots only make sense across at least two models.
    if success_count > 0 and len(models_to_plot) > 1:
        _plot_comparisons(args.data, models_to_plot, out_dir)

    # Build HF dataset if requested
    if args.build_dataset and success_count > 0:
        print(f"\nBuilding HF dataset → {args.build_dataset}")
        try:
            from build_hf_dataset import build_dataset
            ds = build_dataset(out_dir)
            ds.push_to_hub(args.build_dataset)
            print(f"Pushed: https://huggingface.co/datasets/{args.build_dataset}")
        except Exception as e:
            print(f"  *** Dataset build failed: {e}")

    # Reminder to regenerate viewer
    if success_count > 0:
        print("\nTo view in browser, regenerate the viewer index:")
        print(f"  python scripts/generate_viewer_index.py --out {args.data} --serve")
        print("=" * 70)

    return 0 if fail_count == 0 else 1


if __name__ == "__main__":
    # main() returns the process exit code; SystemExit propagates it.
    raise SystemExit(main())