# bitlooplm-small / analyze_quant_stability.py
# (repository page residue, not part of the script: "wmertens's picture", "learnings", 34f2e1c)
"""
Analyse how 'committed' the master weights are to their current ternary bins.
For each BitLinear layer:
- Each fp32 weight w is scaled by s = 1/mean(|w|), rounded, clamped to {-1,0,+1}.
- Bin boundaries (in scaled space) are ±0.5.
- A weight is N Lion-steps away from flipping if it's within N*lr of the
nearest boundary in unscaled-weight space, AND the gradient sign happens
to push it across. (So this is a *lower bound* on steps to flip.)
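Reads ./bitlooplm-checkpoints/resume.pt (override via the CKPT environment variable).
The Lion learning rate is configurable via the LR environment variable
(default = current cosine-floor LR ≈ 5.4e-5). MODEL_SIZE and NUM_LOOPS
(also environment variables) select the model configuration to instantiate.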
"""
import os, sys, math
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, BitLinear
# Runtime configuration; every value except the bucket list can be overridden
# through an environment variable of the same name.

# Path of the checkpoint file to analyse.
CKPT = os.getenv("CKPT", "./bitlooplm-checkpoints/resume.pt")
# Learning rate, used as the per-step weight change when estimating how many
# steps a weight needs to reach a bin boundary.
LR = float(os.getenv("LR", "5.4e-5"))
# Key into MODEL_CONFIGS selecting the model configuration.
MODEL_SIZE = os.getenv("MODEL_SIZE", "small")
# Value stored under the configuration's "num_loops" entry.
NUM_LOOPS = int(os.getenv("NUM_LOOPS", "4"))
# Step-count thresholds for the report: powers of ten from 1 to 100000.
N_STEPS_BUCKETS = [10 ** exponent for exponent in range(6)]
def main():
    """Report how close each BitLinear weight is to flipping its ternary bin.

    Builds a ``BitLoopLM`` from ``MODEL_CONFIGS[MODEL_SIZE]`` (with its
    ``num_loops`` entry replaced by ``NUM_LOOPS``), loads the checkpoint at
    ``CKPT`` onto the CPU and, for every ``BitLinear`` module, computes:

    * the ternary bin of each weight: ``round(w * s)`` clamped to [-1, 1],
      with ``s = 1 / mean(|w|)`` (the mean is floored at 1e-5);
    * the distance of each weight to the nearest bin boundary (±0.5 in scaled
      space), converted back to weight units and divided by ``LR`` to obtain
      the minimum number of steps before the weight could change bin.

    An aggregate summary and a per-layer table are printed to stdout.
    Mismatches between checkpoint keys and model keys are reported on stderr,
    because ``strict=False`` would otherwise hide them and the statistics
    would silently describe freshly initialised weights.
    """
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point CKPT at checkpoint files you trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        state = state["model"]

    # Loading stays non-strict (best effort), but any key mismatch is surfaced
    # so that the numbers below are not mistaken for checkpoint statistics.
    load_result = model.load_state_dict(state, strict=False)
    for label in ("missing", "unexpected"):
        keys = list(getattr(load_result, f"{label}_keys", None) or [])
        if keys:
            shown = ", ".join(keys[:5])
            suffix = ", ..." if len(keys) > 5 else ""
            print(
                f"warning: {len(keys)} {label} state-dict key(s): {shown}{suffix}",
                file=sys.stderr,
            )

    print(f"checkpoint: {CKPT}")
    print(f"Lion LR : {LR:.2e} (per-step weight change magnitude ≈ this)")
    print()

    total_w = 0
    total_bin = {-1: 0, 0: 0, 1: 0}
    total_in_steps = {n: 0 for n in N_STEPS_BUCKETS}
    layer_summary = []
    for name, mod in model.named_modules():
        if not isinstance(mod, BitLinear):
            continue
        with torch.no_grad():
            w = mod.weight.detach().float()
            n_w = w.numel()
            if n_w == 0:
                # An empty weight has no mean and would divide by zero in the
                # per-layer percentages; there is nothing to report for it.
                continue
            s = 1.0 / w.abs().mean().clamp(min=1e-5)
            sw = w * s  # weights in scaled space, where bins are -1 / 0 / +1
            bin_assign = sw.round().clamp(-1, 1)
            # Bin boundaries sit at ±0.5 in scaled space: |sw| < 0.5 is bin 0,
            # anything else is ±1. Either way the distance to the nearest
            # crossover is ||sw| - 0.5|.
            dist_scaled = (sw.abs() - 0.5).abs()
            dist_weight = dist_scaled / s  # back to unscaled weight units
            min_steps = dist_weight / LR

            n_neg = int((bin_assign == -1).sum())
            n_zero = int((bin_assign == 0).sum())
            n_pos = int((bin_assign == 1).sum())
            total_w += n_w
            total_bin[-1] += n_neg
            total_bin[0] += n_zero
            total_bin[1] += n_pos

            in_steps_layer = {}
            for n in N_STEPS_BUCKETS:
                count = int((min_steps <= n).sum())
                in_steps_layer[n] = count
                total_in_steps[n] += count
            layer_summary.append((name, n_w, s.item(), n_neg, n_zero, n_pos, in_steps_layer))

    if total_w == 0:
        # Without this guard the percentage computation below would raise
        # ZeroDivisionError when the model contains no BitLinear weights.
        print("No BitLinear weights found in the model; nothing to analyse.")
        return

    def pct(count):
        """Return ``count`` as a percentage of all BitLinear weights."""
        return 100 * count / total_w

    def layer_pct(count, layer_total):
        """Return ``count`` as a percentage of one layer's weights."""
        return 100 * count / layer_total

    # Aggregate report first (most useful)
    print(f"=== Aggregate over {total_w:,} BitLinear weights ===")
    print(f" Bin distribution: -1 {pct(total_bin[-1]):5.2f}% 0 {pct(total_bin[0]):5.2f}% +1 {pct(total_bin[1]):5.2f}%")
    print(" Weights within N steps of flipping (best-case, assumes consistent gradient):")
    for n in N_STEPS_BUCKETS:
        c = total_in_steps[n]
        print(f" ≤ {n:>6d} steps: {pct(c):6.4f}% ({c:>10,} weights)")

    # Per-layer breakdown
    print()
    print("=== Per-layer breakdown ===")
    print(f"{'layer':70s} {'n_weights':>12s} {'scale_s':>8s} {'-1%':>5s} {'0%':>5s} {'+1%':>5s} {'≤10 step%':>9s} {'≤1k step%':>9s} {'≤10k step%':>10s}")
    for name, n_w, s, n_neg, n_zero, n_pos, in_steps in layer_summary:
        print(f"{name:70s} {n_w:>12,} {s:>8.4f} "
              f"{layer_pct(n_neg, n_w):>4.1f} {layer_pct(n_zero, n_w):>4.1f} {layer_pct(n_pos, n_w):>4.1f} "
              f"{layer_pct(in_steps[10], n_w):>8.4f} {layer_pct(in_steps[1000], n_w):>8.4f} {layer_pct(in_steps[10000], n_w):>9.4f}")
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()