# Source: bitlooplm-small / eval_cpu.py
# Author: wmertens — commit 34f2e1c ("learnings")
"""
CPU-only held-out evaluation for BitLoopLM.
Runs safely alongside a live GPU training job — only uses CPU + RAM.
Loads the weights currently saved to $CKPT (default: pytorch_model.bin
written by save_and_push at SAVE_INTERVAL milestones), streams a
held-out slice of the training dataset, and reports pure CE + perplexity.
Config via env vars (all optional):
CKPT pytorch_model.bin / resume.pt path (default: ./bitlooplm-checkpoints/pytorch_model.bin)
MODEL_SIZE "small" | "tiny" (must match the checkpoint)
NUM_LOOPS number of recurrent loops (must match)
EVAL_DATASET HF dataset name
EVAL_DATASET_CONFIG HF dataset config
EVAL_SKIP samples to discard before taking the eval window (meant to keep the window held-out — confirm training has not streamed past this many samples)
EVAL_BATCHES number of evaluation batches
EVAL_SEQ_LEN tokens per sample (smaller = faster on CPU)
EVAL_BATCH_SIZE batch dim
TOKENIZER HF tokenizer id
FAST "1" (default) = pre-freeze BitLinear weights + bf16 cast.
"0" disables for an apples-to-apples baseline comparison.
Usage:
uv run python eval_cpu.py
CKPT=./bitlooplm-checkpoints/resume.pt EVAL_BATCHES=16 uv run python eval_cpu.py
FAST=0 uv run python eval_cpu.py # baseline (slower, fp32, no freeze)
"""
import math
import os
import sys
import time
import torch
import torch.nn.functional as F
# Import model from the training script (same directory)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import (
BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS,
BitLinear, weight_quant, activation_quant,
)
# Runtime configuration, read once from the environment at import time.
# Every knob is optional; the module docstring describes what each controls.
_env = os.environ.get


def _env_int(name, default):
    """Return env var ``name`` parsed as int, falling back to the string ``default``."""
    return int(_env(name, default))


CKPT = _env("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = _env("MODEL_SIZE", "small")
NUM_LOOPS = _env_int("NUM_LOOPS", "4")
TOKENIZER = _env("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
EVAL_DATASET = _env("EVAL_DATASET", "HuggingFaceTB/smollm-corpus")
EVAL_DATASET_CONFIG = _env("EVAL_DATASET_CONFIG", "cosmopedia-v2")
EVAL_SKIP = _env_int("EVAL_SKIP", "50000")
EVAL_BATCHES = _env_int("EVAL_BATCHES", "8")
EVAL_SEQ_LEN = _env_int("EVAL_SEQ_LEN", "256")
EVAL_BATCH_SIZE = _env_int("EVAL_BATCH_SIZE", "4")
# Fast mode is on unless FAST is set to anything other than the literal "1".
FAST = _env("FAST", "1") == "1"
def _frozen_bitlinear_forward(self, x):
    """Eval-only BitLinear forward that reuses the cached quantized weight.

    Quantizes the activations, multiplies by ``self._quant_weight_cached``,
    and adds ``self.bias`` when the layer has one.

    NOTE(review): this assumes the training BitLinear.forward computes
    ``linear(activation_quant(x), weight_quant(W)) + bias`` with nothing
    else (e.g. no input norm) — confirm against the training script.
    """
    out = F.linear(activation_quant(x), self._quant_weight_cached)
    return out if self.bias is None else out + self.bias


def freeze_bitlinears(model):
    """Quantize every BitLinear weight once and patch in a lean eval forward.

    At eval time the weights never change, so re-running ``weight_quant`` and
    the straight-through-estimator plumbing on each call is wasted work. For
    each BitLinear found in ``model.modules()`` this:

    1. computes ``weight_quant(weight)`` once under ``torch.no_grad()``,
    2. stores it as the non-persistent buffer ``_quant_weight_cached`` — a
       buffer, so a later ``model.to(dtype)`` casts it together with the
       rest of the model, and non-persistent, so it stays out of state_dict,
    3. rebinds the instance's ``forward`` to ``_frozen_bitlinear_forward``.

    Call this before any dtype cast so quantization runs on the original
    weights.

    Returns:
        How many BitLinear layers were patched.
    """
    import types

    targets = [m for m in model.modules() if isinstance(m, BitLinear)]
    for layer in targets:
        with torch.no_grad():
            frozen = weight_quant(layer.weight).detach().clone()
        layer.register_buffer("_quant_weight_cached", frozen, persistent=False)
        # Bind per instance so only these layers change, not the BitLinear class.
        layer.forward = types.MethodType(_frozen_bitlinear_forward, layer)
    return len(targets)
def _load_model():
    """Build BitLoopLM for MODEL_SIZE / NUM_LOOPS and load weights from CKPT.

    When FAST is set, BitLinear weights are pre-quantized with
    ``freeze_bitlinears`` and the whole model is then cast to bf16.

    Returns:
        ``(model, config)``: the model in eval mode on CPU and the
        ``BitLoopLMConfig`` it was built from.

    Raises:
        SystemExit: if MODEL_SIZE is not a key of MODEL_CONFIGS.
    """
    if MODEL_SIZE not in MODEL_CONFIGS:
        raise SystemExit(
            f"[eval] unknown MODEL_SIZE={MODEL_SIZE!r}; expected one of {list(MODEL_CONFIGS)}"
        )
    print(f"[eval] building model: size={MODEL_SIZE} loops={NUM_LOOPS}")
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[eval] loading checkpoint: {CKPT}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects — only
    # point CKPT at checkpoints you trust. Presumably kept because resume.pt
    # holds non-tensor training state (TODO confirm whether weights_only=True works).
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        # resume.pt wraps weights under "model"
        state = state["model"]
    # strict=False: load what matches and report mismatches instead of raising.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[eval] WARNING missing keys: {len(missing)} (e.g., {missing[:3]})")
    if unexpected:
        print(f"[eval] WARNING unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[eval] model loaded, {n_params/1e6:.1f}M params")

    if FAST:
        # Tier 1+2: bf16 GEMM + pre-frozen BitLinear weights.
        torch.set_float32_matmul_precision("medium")
        n_frozen = freeze_bitlinears(model)
        # Cast AFTER freezing so weight_quant runs on the loaded-precision
        # weights; the cached quantized buffer is then cast to bf16 too.
        model = model.to(torch.bfloat16)
        print(f"[eval] FAST mode: pre-froze {n_frozen} BitLinears, cast model to bf16")
    return model, config


def _collect_tokens(tokenizer, needed):
    """Stream the held-out slice of EVAL_DATASET into a flat list of token ids.

    Skips the first EVAL_SKIP samples, then tokenizes each non-empty document
    (without special tokens) and appends it until at least ``needed`` ids are
    buffered or the stream ends.

    Returns:
        The token-id list; it may be shorter than ``needed``.
    """
    print(f"[eval] streaming {EVAL_DATASET}/{EVAL_DATASET_CONFIG}, skip={EVAL_SKIP}")
    from datasets import load_dataset
    ds = load_dataset(
        EVAL_DATASET, EVAL_DATASET_CONFIG,
        split="train", streaming=True,
    )
    ds = ds.skip(EVAL_SKIP)
    print(f"[eval] need {needed} tokens for {EVAL_BATCHES} batches of {EVAL_BATCH_SIZE}x{EVAL_SEQ_LEN}")

    buffer = []
    samples_consumed = 0
    t0 = time.perf_counter()
    for sample in ds:
        # Datasets differ in which field holds the document body.
        text = sample.get("text") or sample.get("content") or ""
        if not text:
            continue
        # NOTE(review): documents are joined with no EOS/separator, so some
        # targets straddle document boundaries — confirm this matches training.
        buffer.extend(tokenizer.encode(text, add_special_tokens=False))
        samples_consumed += 1
        if len(buffer) >= needed:
            break
    print(f"[eval] collected {len(buffer)} tokens from {samples_consumed} samples in {time.perf_counter()-t0:.1f}s")
    return buffer


def main():
    """Evaluate the checkpoint at CKPT on a held-out streamed slice (CPU only).

    Prints per-batch and overall next-token cross-entropy (nats), perplexity,
    the mean exit distribution, and model-forward timing. If the stream runs
    dry, only the complete batches collected are evaluated (with a warning).

    Raises:
        SystemExit: on invalid batch geometry, an unknown MODEL_SIZE, or when
            not even one full batch of tokens could be collected.
    """
    if EVAL_BATCH_SIZE < 1 or EVAL_SEQ_LEN < 2 or EVAL_BATCHES < 1:
        # EVAL_SEQ_LEN < 2 leaves no next-token targets; any of these would
        # otherwise end in a reshape error or a division by zero below.
        raise SystemExit(
            "[eval] need EVAL_BATCH_SIZE >= 1, EVAL_SEQ_LEN >= 2, EVAL_BATCHES >= 1 "
            f"(got {EVAL_BATCH_SIZE}, {EVAL_SEQ_LEN}, {EVAL_BATCHES})"
        )

    # Leave one core free so a concurrent training job keeps some headroom.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    device = torch.device("cpu")
    print(f"[eval] device=cpu threads={torch.get_num_threads()}")
    print(f"[eval] loading tokenizer: {TOKENIZER}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    model, config = _load_model()

    tokens_per_batch = EVAL_BATCH_SIZE * EVAL_SEQ_LEN
    buffer = _collect_tokens(tokenizer, EVAL_BATCHES * tokens_per_batch)

    # Only complete batches can be reshaped to (EVAL_BATCH_SIZE, EVAL_SEQ_LEN).
    n_batches = min(EVAL_BATCHES, len(buffer) // tokens_per_batch)
    if n_batches == 0:
        raise SystemExit(
            f"[eval] only {len(buffer)} tokens collected; need {tokens_per_batch} for one batch"
        )
    if n_batches < EVAL_BATCHES:
        print(f"[eval] WARNING stream ended early: evaluating {n_batches}/{EVAL_BATCHES} batches")

    total_ce = 0.0
    total_tokens = 0
    total_forward_time = 0.0
    total_exit = torch.zeros(config.num_loops)
    print(f"[eval] running {n_batches} batches")
    with torch.no_grad():
        for b in range(n_batches):
            off = b * tokens_per_batch
            batch = torch.tensor(
                buffer[off:off + tokens_per_batch], dtype=torch.long, device=device
            ).view(EVAL_BATCH_SIZE, EVAL_SEQ_LEN)

            # Time only the model call so "forward" excludes the CE computation.
            t0 = time.perf_counter()
            # Inference path: returns (weighted-combined logits, exit_pdf)
            logits, exit_pdf = model(batch)
            dt = time.perf_counter() - t0
            total_forward_time += dt

            # Cast logits to fp32 for CE — bf16 softmax can underflow on the
            # rare-token tail. Flatten by the logits' own last dim rather than
            # config.vocab_size so a mismatch cannot silently mis-reshape.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = batch[:, 1:]
            batch_ce = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="sum",
            ).item()
            batch_tokens = shift_labels.numel()
            total_ce += batch_ce
            total_tokens += batch_tokens
            # Assumes exit_pdf is (batch, seq, num_loops) — TODO confirm against
            # BitLoopLM.forward. Averaged in fp32 to avoid bf16 rounding.
            total_exit += exit_pdf.float().mean(dim=(0, 1))

            avg = total_ce / total_tokens
            print(
                f" batch {b+1}/{n_batches}: ce={batch_ce/batch_tokens:.4f} "
                f"running_avg={avg:.4f} ppl={math.exp(avg):.2f} "
                f"forward={dt:.1f}s"
            )
    # Per-loop CE is not computed here (it would need a second forward pass);
    # use the training logs for per-loop diagnostics.

    avg_ce = total_ce / total_tokens
    ppl = math.exp(avg_ce)
    exit_mean = (total_exit / n_batches).tolist()
    print("\n=== Eval summary ===")
    print(f" checkpoint : {CKPT}")
    print(f" held-out dataset: {EVAL_DATASET}/{EVAL_DATASET_CONFIG} skip={EVAL_SKIP}")
    print(f" tokens evaluated: {total_tokens}")
    print(f" avg CE (nats) : {avg_ce:.4f}")
    print(f" perplexity : {ppl:.2f}")
    print(f" exit pdf : {[f'L{i}:{v:.2f}' for i, v in enumerate(exit_mean)]}")
    print(f" forward time : {total_forward_time:.1f}s total ({total_forward_time/n_batches:.1f}s/batch)")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()