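"""
Analyse how 'committed' the master weights are to their current ternary bins.

For each BitLinear layer:
  - Each fp32 master weight w is multiplied by s = 1/mean(|w|), rounded to the
    nearest integer, and clamped to {-1, 0, +1}.
  - In scaled space the only bin boundaries are at ±0.5. Values beyond ±1.5
    are clamped back to ±1, so there is no further boundary.
  - Lion moves every weight by roughly lr per step. A weight therefore needs
    at least dist/lr steps to flip, where dist = | |w*s| - 0.5 | / s is its
    distance to the nearest boundary in unscaled-weight space.
  - It only actually flips if the update sign keeps pushing it toward that
    boundary. So this is a *lower bound* on steps to flip, and the reported
    "within N steps" fractions are best-case upper bounds.
  - The scale s is held fixed at its current value.

Reads ./bitlooplm-checkpoints/resume.pt (override via the CKPT env var).
The LR is configurable via the LR env var (default 5.4e-5, the current
cosine-floor LR); MODEL_SIZE and NUM_LOOPS select the model config.
"""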
| import os, sys, math |
| import torch |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, BitLinear |
|
|
# Runtime configuration; every value can be overridden through the environment.
CKPT = os.getenv("CKPT", "./bitlooplm-checkpoints/resume.pt")  # checkpoint to inspect
LR = float(os.getenv("LR", "5.4e-5"))  # assumed per-step Lion weight change
MODEL_SIZE = os.getenv("MODEL_SIZE", "small")  # key into MODEL_CONFIGS
NUM_LOOPS = int(os.getenv("NUM_LOOPS", "4"))  # overrides the config's num_loops
# Step horizons reported: 1, 10, 100, 1000, 10000, 100000.
N_STEPS_BUCKETS = [10 ** k for k in range(6)]
|
|
|
|
| def main(): |
| cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE]) |
| cfg_dict["num_loops"] = NUM_LOOPS |
| config = BitLoopLMConfig(**cfg_dict) |
| model = BitLoopLM(config) |
| state = torch.load(CKPT, map_location="cpu", weights_only=False) |
| if isinstance(state, dict) and "model" in state: |
| state = state["model"] |
| model.load_state_dict(state, strict=False) |
|
|
| print(f"checkpoint: {CKPT}") |
| print(f"Lion LR : {LR:.2e} (per-step weight change magnitude ≈ this)") |
| print() |
|
|
| total_w = 0 |
| total_bin = {-1: 0, 0: 0, 1: 0} |
| total_in_steps = {n: 0 for n in N_STEPS_BUCKETS} |
| layer_summary = [] |
|
|
| for name, mod in model.named_modules(): |
| if not isinstance(mod, BitLinear): |
| continue |
| with torch.no_grad(): |
| w = mod.weight.detach().float() |
| s = 1.0 / w.abs().mean().clamp(min=1e-5) |
| sw = w * s |
| bin_assign = sw.round().clamp(-1, 1) |
| |
| |
| |
| dist_scaled = (sw.abs() - 0.5).abs() |
| dist_weight = dist_scaled / s |
| min_steps = dist_weight / LR |
|
|
| n_w = w.numel() |
| n_neg = int((bin_assign == -1).sum()) |
| n_zero = int((bin_assign == 0).sum()) |
| n_pos = int((bin_assign == 1).sum()) |
| total_w += n_w |
| total_bin[-1] += n_neg |
| total_bin[0] += n_zero |
| total_bin[1] += n_pos |
|
|
| in_steps_layer = {} |
| for n in N_STEPS_BUCKETS: |
| count = int((min_steps <= n).sum()) |
| in_steps_layer[n] = count |
| total_in_steps[n] += count |
|
|
| layer_summary.append((name, n_w, s.item(), n_neg, n_zero, n_pos, in_steps_layer)) |
|
|
| |
| print(f"=== Aggregate over {total_w:,} BitLinear weights ===") |
| pct = lambda x: 100 * x / total_w |
| print(f" Bin distribution: -1 {pct(total_bin[-1]):5.2f}% 0 {pct(total_bin[0]):5.2f}% +1 {pct(total_bin[1]):5.2f}%") |
| print(f" Weights within N steps of flipping (best-case, assumes consistent gradient):") |
| for n in N_STEPS_BUCKETS: |
| c = total_in_steps[n] |
| print(f" ≤ {n:>6d} steps: {pct(c):6.4f}% ({c:>10,} weights)") |
|
|
| |
| print() |
| print("=== Per-layer breakdown ===") |
| print(f"{'layer':70s} {'n_weights':>12s} {'scale_s':>8s} {'-1%':>5s} {'0%':>5s} {'+1%':>5s} {'≤10 step%':>9s} {'≤1k step%':>9s} {'≤10k step%':>10s}") |
| for name, n_w, s, n_neg, n_zero, n_pos, in_steps in layer_summary: |
| p = lambda x: 100 * x / n_w |
| print(f"{name:70s} {n_w:>12,} {s:>8.4f} " |
| f"{p(n_neg):>4.1f} {p(n_zero):>4.1f} {p(n_pos):>4.1f} " |
| f"{p(in_steps[10]):>8.4f} {p(in_steps[1000]):>8.4f} {p(in_steps[10000]):>9.4f}") |
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the analysis only when executed as a script, not on import.
    main()
|
|