""" Analyse how 'committed' the master weights are to their current ternary bins. For each BitLinear layer: - Each fp32 weight w is scaled by s = 1/mean(|w|), rounded, clamped to {-1,0,+1}. - Bin boundaries (in scaled space) are ±0.5. - A weight is N Lion-steps away from flipping if it's within N*lr of the nearest boundary in unscaled-weight space, AND the gradient sign happens to push it across. (So this is a *lower bound* on steps to flip.) Reads ./bitlooplm-checkpoints/resume.pt (or override via CKPT). Configurable LR (default = current cosine-floor LR ≈ 5.4e-5). """ import os, sys, math import torch sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, BitLinear CKPT = os.environ.get("CKPT", "./bitlooplm-checkpoints/resume.pt") LR = float(os.environ.get("LR", "5.4e-5")) MODEL_SIZE = os.environ.get("MODEL_SIZE", "small") NUM_LOOPS = int(os.environ.get("NUM_LOOPS", "4")) N_STEPS_BUCKETS = [1, 10, 100, 1000, 10000, 100000] def main(): cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE]) cfg_dict["num_loops"] = NUM_LOOPS config = BitLoopLMConfig(**cfg_dict) model = BitLoopLM(config) state = torch.load(CKPT, map_location="cpu", weights_only=False) if isinstance(state, dict) and "model" in state: state = state["model"] model.load_state_dict(state, strict=False) print(f"checkpoint: {CKPT}") print(f"Lion LR : {LR:.2e} (per-step weight change magnitude ≈ this)") print() total_w = 0 total_bin = {-1: 0, 0: 0, 1: 0} total_in_steps = {n: 0 for n in N_STEPS_BUCKETS} layer_summary = [] for name, mod in model.named_modules(): if not isinstance(mod, BitLinear): continue with torch.no_grad(): w = mod.weight.detach().float() s = 1.0 / w.abs().mean().clamp(min=1e-5) sw = w * s # scaled to [-something, +something] bin_assign = sw.round().clamp(-1, 1) # Distance to nearest bin boundary in scaled space: # boundaries are at ±0.5. A weight is in bin 0 if |sw|<0.5; otherwise ±1. # Either way, distance to crossover = ||sw| - 0.5|. dist_scaled = (sw.abs() - 0.5).abs() dist_weight = dist_scaled / s min_steps = dist_weight / LR n_w = w.numel() n_neg = int((bin_assign == -1).sum()) n_zero = int((bin_assign == 0).sum()) n_pos = int((bin_assign == 1).sum()) total_w += n_w total_bin[-1] += n_neg total_bin[0] += n_zero total_bin[1] += n_pos in_steps_layer = {} for n in N_STEPS_BUCKETS: count = int((min_steps <= n).sum()) in_steps_layer[n] = count total_in_steps[n] += count layer_summary.append((name, n_w, s.item(), n_neg, n_zero, n_pos, in_steps_layer)) # Aggregate report first (most useful) print(f"=== Aggregate over {total_w:,} BitLinear weights ===") pct = lambda x: 100 * x / total_w print(f" Bin distribution: -1 {pct(total_bin[-1]):5.2f}% 0 {pct(total_bin[0]):5.2f}% +1 {pct(total_bin[1]):5.2f}%") print(f" Weights within N steps of flipping (best-case, assumes consistent gradient):") for n in N_STEPS_BUCKETS: c = total_in_steps[n] print(f" ≤ {n:>6d} steps: {pct(c):6.4f}% ({c:>10,} weights)") # Per-layer breakdown print() print("=== Per-layer breakdown ===") print(f"{'layer':70s} {'n_weights':>12s} {'scale_s':>8s} {'-1%':>5s} {'0%':>5s} {'+1%':>5s} {'≤10 step%':>9s} {'≤1k step%':>9s} {'≤10k step%':>10s}") for name, n_w, s, n_neg, n_zero, n_pos, in_steps in layer_summary: p = lambda x: 100 * x / n_w print(f"{name:70s} {n_w:>12,} {s:>8.4f} " f"{p(n_neg):>4.1f} {p(n_zero):>4.1f} {p(n_pos):>4.1f} " f"{p(in_steps[10]):>8.4f} {p(in_steps[1000]):>8.4f} {p(in_steps[10000]):>9.4f}") if __name__ == "__main__": main()