File size: 5,028 Bytes
5fbb1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
 
 
 
5fbb1fb
230508d
 
 
5fbb1fb
 
 
 
 
 
 
 
 
230508d
5fbb1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
5fbb1fb
 
 
 
 
 
 
 
 
 
 
 
87b2fa6
 
 
 
 
5fbb1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""Run linear probes on all trained checkpoints and write results to JSON."""

import argparse
import json
import sys
from pathlib import Path

import torch

from pawn.config import CLMConfig
from pawn.model import PAWNCLM
from pawn.eval_suite.probes import extract_probe_data, train_all_probes


def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
    """Load a PAWNCLM backbone from *checkpoint_path* onto *device*, in eval mode.

    Uses the config stored in the checkpoint when present; otherwise infers the
    architecture from the state-dict tensor shapes and maps known
    (d_model, n_layers) pairs onto the named config presets.
    """
    from pawn.checkpoint import load_backbone_weights

    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
    if model_config:
        cfg = CLMConfig(**model_config)
    else:
        # No stored config — reconstruct it from the weights themselves.
        d_model = state_dict["embed.src_embed.weight"].shape[1]
        layer_ids = [int(k.split(".")[1]) for k in state_dict if k.startswith("layers.")]
        n_layers = max(layer_ids) + 1
        # Known preset shapes; anything else falls back to an ad-hoc config.
        presets = {
            (256, 8): CLMConfig.small,
            (512, 8): CLMConfig.base,
            (640, 10): CLMConfig.large,
        }
        factory = presets.get((d_model, n_layers))
        cfg = factory() if factory is not None else CLMConfig(d_model=d_model, n_layers=n_layers)

    model = PAWNCLM(cfg).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def main():
    parser = argparse.ArgumentParser(description="Run linear probes on checkpoints")
    parser.add_argument("--log-dir", type=str, default="logs", help="Log directory containing run dirs")
    parser.add_argument("--n-games", type=int, default=4096, help="Games for probe train set")
    parser.add_argument("--n-val-games", type=int, default=1024, help="Games for probe val set")
    parser.add_argument("--n-epochs", type=int, default=20, help="Probe training epochs")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--run", type=str, default=None, help="Only evaluate this run dir name")
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device == "cuda":
        from pawn.gpu import configure_gpu
        gpu_cfg = configure_gpu()
        import pawn.model as model_module
        model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")

    log_dir = Path(args.log_dir)

    # Find all runs with checkpoints
    runs = []
    for config_path in sorted(log_dir.glob("run_*/config.json")):
        run_dir = config_path.parent
        if args.run and run_dir.name != args.run:
            continue
        # Find checkpoints: directory-based (safetensors) or legacy .pt
        checkpoints = sorted(
            [d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
            or list(run_dir.glob("checkpoints/step_*.pt"))
        )
        if not checkpoints:
            continue
        latest = checkpoints[-1]
        with open(config_path) as f:
            cfg = json.load(f)
        runs.append((run_dir, latest, cfg))

    if not runs:
        print("No runs with checkpoints found.")
        sys.exit(1)

    print(f"Found {len(runs)} runs to evaluate")

    # Generate probe data once (shared across all models with same max_ply)
    max_ply = 256
    print(f"\nGenerating probe data: {args.n_games} train + {args.n_val_games} val games...")
    train_data = extract_probe_data(args.n_games, max_ply, seed=12345)
    val_data = extract_probe_data(args.n_val_games, max_ply, seed=54321)
    print("Done.")

    for run_dir, ckpt_path, run_cfg in runs:
        model_cfg = run_cfg.get("model", {})
        train_cfg = run_cfg.get("training", {})
        variant = f"{model_cfg.get('d_model', '?')}d/{model_cfg.get('n_layers', '?')}L"
        discard = train_cfg.get("discard_ply_limit", False)
        step = ckpt_path.stem.replace("step_", "")

        print(f"\n{'='*60}")
        print(f"Run: {run_dir.name}  ({variant}, discard_ply={discard}, step={step})")
        print(f"Checkpoint: {ckpt_path}")
        print(f"{'='*60}")

        model = load_model_from_checkpoint(str(ckpt_path), device)

        results = train_all_probes(
            model, train_data, val_data, device,
            per_layer=True, n_epochs=args.n_epochs, verbose=True,
        )

        # Save results
        output = {
            "run": run_dir.name,
            "checkpoint": str(ckpt_path),
            "step": int(step),
            "variant": variant,
            "discard_ply_limit": discard,
            "model_config": model_cfg,
            "probes": {
                pname: {
                    lname: {k: round(v, 6) if isinstance(v, float) else v for k, v in metrics.items()}
                    for lname, metrics in layer_results.items()
                }
                for pname, layer_results in results.items()
            },
        }

        out_path = run_dir / "probe_results.json"
        with open(out_path, "w") as f:
            json.dump(output, f, indent=2)
        print(f"\nSaved: {out_path}")

        del model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None


# Entry point: run only when executed as a script, not on import.
if __name__ == "__main__":
    main()