File size: 4,104 Bytes
34f2e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Analyse how 'committed' the master weights are to their current ternary bins.

For each BitLinear layer:
  - Each fp32 weight w is scaled by s = 1/mean(|w|), rounded, clamped to {-1,0,+1}.
  - Bin boundaries (in scaled space) are ±0.5.
  - A weight is N Lion-steps away from flipping if it's within N*lr of the
    nearest boundary in unscaled-weight space, AND the gradient sign happens
    to push it across. (So this is a *lower bound* on steps to flip.)

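Reads ./bitlooplm-checkpoints/resume.pt (override via the CKPT env var).
The Lion LR is set via the LR env var (default = current cosine-floor LR ≈ 5.4e-5).
The model configuration is selected via the MODEL_SIZE and NUM_LOOPS env vars.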
"""
import os, sys, math
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, BitLinear

# Runtime configuration; every value can be overridden through an
# environment variable of the same name.
CKPT = os.getenv("CKPT", "./bitlooplm-checkpoints/resume.pt")  # checkpoint path
LR = float(os.getenv("LR", "5.4e-5"))  # Lion learning rate used for step estimates
MODEL_SIZE = os.getenv("MODEL_SIZE", "small")  # key into MODEL_CONFIGS
NUM_LOOPS = int(os.getenv("NUM_LOOPS", "4"))  # written to the config's "num_loops"
# Step horizons reported by the analysis: 1, 10, 100, ..., 100000.
N_STEPS_BUCKETS = [10 ** exponent for exponent in range(6)]


def _pct(count, total):
    """Return ``count`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return 100 * count / total if total else 0.0


def _layer_stats(weight, lr, buckets):
    """Compute ternary-bin statistics for one weight tensor.

    Args:
        weight: the layer's weight tensor.
        lr: per-step weight change magnitude used to convert a distance
            into a number of optimizer steps.
        buckets: iterable of step horizons N to count weights for.

    Returns:
        A tuple ``(n_weights, scale, bin_counts, in_steps)`` where
        ``bin_counts`` maps each of -1, 0, +1 to the number of weights in
        that bin and ``in_steps`` maps each N in ``buckets`` to the number
        of weights whose minimum steps-to-flip is <= N.
    """
    with torch.no_grad():
        w = weight.detach().float()
        # Scale is 1 / mean(|w|); the clamp avoids division by zero.
        s = 1.0 / w.abs().mean().clamp(min=1e-5)
        sw = w * s
        bin_assign = sw.round().clamp(-1, 1)
        # Bin boundaries sit at +/-0.5 in scaled space. A weight is in bin 0
        # if |sw| < 0.5 and in a +/-1 bin otherwise; in both cases the
        # distance to the nearest crossover is ||sw| - 0.5|.
        dist_scaled = (sw.abs() - 0.5).abs()
        dist_weight = dist_scaled / s  # back to unscaled-weight space
        min_steps = dist_weight / lr

        bin_counts = {b: int((bin_assign == b).sum()) for b in (-1, 0, 1)}
        in_steps = {n: int((min_steps <= n).sum()) for n in buckets}
    return w.numel(), s.item(), bin_counts, in_steps


def main():
    """Load the checkpoint and report how close weights are to flipping bins.

    Builds the model from ``MODEL_CONFIGS[MODEL_SIZE]`` with ``NUM_LOOPS``
    loops, loads ``CKPT`` into it, then for every ``BitLinear`` module prints
    the ternary bin distribution and the share of weights that lie within
    N steps (at step size ``LR``) of a bin boundary. An aggregate report is
    printed first, followed by a per-layer table.

    Raises:
        SystemExit: if ``LR`` is not a positive number, since the
            steps-to-flip estimate divides by it.
    """
    # `not LR > 0` also rejects NaN.
    if not LR > 0:
        raise SystemExit(f"LR must be a positive number, got {LR!r}")

    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code. Only point CKPT at checkpoints from a trusted source.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    load_result = model.load_state_dict(state, strict=False)
    # With strict=False a key mismatch is not an error, so surface it:
    # otherwise the report could silently describe freshly initialised weights.
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        print(
            f"WARNING: checkpoint/model key mismatch: "
            f"{len(missing)} missing, {len(unexpected)} unexpected "
            f"(first missing: {missing[:3]}, first unexpected: {unexpected[:3]})",
            file=sys.stderr,
        )

    print(f"checkpoint: {CKPT}")
    print(f"Lion LR    : {LR:.2e}  (per-step weight change magnitude ≈ this)")
    print()

    total_w = 0
    total_bin = {-1: 0, 0: 0, 1: 0}
    total_in_steps = {n: 0 for n in N_STEPS_BUCKETS}
    layer_summary = []

    for name, mod in model.named_modules():
        if not isinstance(mod, BitLinear):
            continue
        n_w, scale, bin_counts, in_steps_layer = _layer_stats(
            mod.weight, LR, N_STEPS_BUCKETS
        )
        total_w += n_w
        for b, count in bin_counts.items():
            total_bin[b] += count
        for n, count in in_steps_layer.items():
            total_in_steps[n] += count
        layer_summary.append(
            (name, n_w, scale, bin_counts[-1], bin_counts[0], bin_counts[1], in_steps_layer)
        )

    if total_w == 0:
        # Nothing to report; also avoids dividing by zero below.
        print("No BitLinear weights found in the model; nothing to analyse.")
        return

    # Aggregate report first (most useful)
    print(f"=== Aggregate over {total_w:,} BitLinear weights ===")
    print(
        f"  Bin distribution: -1 {_pct(total_bin[-1], total_w):5.2f}%   "
        f"0 {_pct(total_bin[0], total_w):5.2f}%   "
        f"+1 {_pct(total_bin[1], total_w):5.2f}%"
    )
    print("  Weights within N steps of flipping (best-case, assumes consistent gradient):")
    for n in N_STEPS_BUCKETS:
        c = total_in_steps[n]
        print(f"    ≤ {n:>6d} steps: {_pct(c, total_w):6.4f}%   ({c:>10,} weights)")

    # Per-layer breakdown
    print()
    print("=== Per-layer breakdown ===")
    print(f"{'layer':70s}  {'n_weights':>12s}  {'scale_s':>8s}  {'-1%':>5s}  {'0%':>5s}  {'+1%':>5s}  {'≤10 step%':>9s}  {'≤1k step%':>9s}  {'≤10k step%':>10s}")
    for name, n_w, s, n_neg, n_zero, n_pos, in_steps in layer_summary:
        # Field widths match the header so the columns line up.
        print(f"{name:70s}  {n_w:>12,}  {s:>8.4f}  "
              f"{_pct(n_neg, n_w):>5.1f}  {_pct(n_zero, n_w):>5.1f}  {_pct(n_pos, n_w):>5.1f}  "
              f"{_pct(in_steps[10], n_w):>9.4f}  {_pct(in_steps[1000], n_w):>9.4f}  "
              f"{_pct(in_steps[10000], n_w):>10.4f}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()