"""
CPU-only held-out evaluation for BitLoopLM.
Runs safely alongside a live GPU training job — only uses CPU + RAM.
Loads the weights currently saved to $CKPT (default: pytorch_model.bin
written by save_and_push at SAVE_INTERVAL milestones), streams a
held-out slice of the training dataset, and reports pure CE + perplexity.
Config via env vars (all optional):
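CKPT pytorch_model.bin or resume.pt path (default: ./bitlooplm-checkpoints/pytorch_model.bin)
MODEL_SIZE "small" | "tiny" (must match the checkpoint; default "small")
NUM_LOOPS number of recurrent loops (must match the checkpoint; default 4)
EVAL_DATASET HF dataset name (default HuggingFaceTB/smollm-corpus)
EVAL_DATASET_CONFIG HF dataset config (default cosmopedia-v2)
EVAL_SKIP samples to discard before taking the eval window (default 50000).
The window is only held-out if training has not yet streamed past this offset.
EVAL_BATCHES number of evaluation batches (default 8)
EVAL_SEQ_LEN tokens per sample; smaller = faster on CPU (default 256)
EVAL_BATCH_SIZE samples per batch (default 4)
TOKENIZER HF tokenizer id (default HuggingFaceTB/SmolLM2-135M)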
FAST "1" (default) = pre-freeze BitLinear weights + bf16 cast.
"0" disables for an apples-to-apples baseline comparison.
Usage:
uv run python eval_cpu.py
CKPT=./bitlooplm-checkpoints/resume.pt EVAL_BATCHES=16 uv run python eval_cpu.py
FAST=0 uv run python eval_cpu.py # baseline (slower, fp32, no freeze)
"""
import math
import os
import sys
import time
import torch
import torch.nn.functional as F
# Import model from the training script (same directory)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import (
BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS,
BitLinear, weight_quant, activation_quant,
)
# --- Runtime configuration, read once from the environment at import time. ---


def _int_env(name, default):
    """Return env var *name* parsed as ``int``, or ``int(default)`` when unset."""
    return int(os.environ.get(name, default))


# String-valued settings.
CKPT = os.environ.get("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = os.environ.get("MODEL_SIZE", "small")
TOKENIZER = os.environ.get("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
EVAL_DATASET = os.environ.get("EVAL_DATASET", "HuggingFaceTB/smollm-corpus")
EVAL_DATASET_CONFIG = os.environ.get("EVAL_DATASET_CONFIG", "cosmopedia-v2")

# Integer-valued settings (a non-numeric value raises ValueError at import).
NUM_LOOPS = _int_env("NUM_LOOPS", "4")
EVAL_SKIP = _int_env("EVAL_SKIP", "50000")
EVAL_BATCHES = _int_env("EVAL_BATCHES", "8")
EVAL_SEQ_LEN = _int_env("EVAL_SEQ_LEN", "256")
EVAL_BATCH_SIZE = _int_env("EVAL_BATCH_SIZE", "4")

# FAST is on unless the variable is set to anything other than exactly "1".
FAST = os.environ.get("FAST", "1") == "1"
def freeze_bitlinears(model):
    """Patch every BitLinear in *model* with an eval-only forward using a cached weight.

    For each ``BitLinear`` found via ``model.modules()``, ``weight_quant`` is
    evaluated once (under ``no_grad``) and stored as the non-persistent buffer
    ``_quant_weight_cached``. Because it is a buffer, a later ``model.to(dtype)``
    casts it together with the rest of the model. The instance's ``forward`` is
    then replaced by one that computes
    ``F.linear(activation_quant(x), cached_weight)`` plus the bias when present.

    The original ``weight`` parameter is left in place, so it still occupies
    memory. Call this only for inference: later updates to ``weight`` are not
    reflected in the cached copy.

    NOTE(review): this is only output-equivalent if the training script's
    ``BitLinear.forward`` is exactly activation_quant -> linear with the
    quantized weight (+ bias). ``BitLinear`` is not visible here; if it also
    normalizes its input, results will differ. Compare against FAST=0.

    Returns the number of layers patched.
    """
    import types

    def _frozen_forward(self, x):
        # Weights are static at eval time, so reuse the pre-quantized copy
        # instead of re-running weight_quant (and its STE plumbing) per call.
        projected = F.linear(activation_quant(x), self._quant_weight_cached)
        return projected if self.bias is None else projected + self.bias

    patched = 0
    for layer in (m for m in model.modules() if isinstance(m, BitLinear)):
        with torch.no_grad():
            cached = weight_quant(layer.weight).detach().clone()
        layer.register_buffer("_quant_weight_cached", cached, persistent=False)
        # Bind as an instance attribute so nn.Module.__call__ picks it up
        # in place of the class-level forward.
        layer.forward = types.MethodType(_frozen_forward, layer)
        patched += 1
    return patched
def _load_model():
    """Build BitLoopLM for MODEL_SIZE / NUM_LOOPS and load the weights from CKPT.

    Returns ``(model, config)``. The model is in eval mode on CPU. When FAST is
    set, its BitLinears are frozen and the whole model is cast to bfloat16.
    """
    print(f"[eval] building model: size={MODEL_SIZE} loops={NUM_LOOPS}")
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[eval] loading checkpoint: {CKPT}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # It is presumably needed because resume.pt stores non-tensor state next to
    # "model". Only point CKPT at checkpoints you produced or otherwise trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        # resume.pt wraps weights under "model"
        state = state["model"]
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[eval] WARNING missing keys: {len(missing)} (e.g., {missing[:3]})")
    if unexpected:
        print(f"[eval] WARNING unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[eval] model loaded, {n_params/1e6:.1f}M params")

    if FAST:
        # Tier 1+2: bf16 GEMM + pre-frozen BitLinear weights.
        torch.set_float32_matmul_precision("medium")
        n_frozen = freeze_bitlinears(model)
        # Cast AFTER freezing so weight_quant runs on fp32 weights, then the
        # cached quantized buffer is cast to bf16 along with everything else.
        model = model.to(torch.bfloat16)
        print(f"[eval] FAST mode: pre-froze {n_frozen} BitLinears, cast model to bf16")
    return model, config


def _collect_tokens(ds, tokenizer, needed):
    """Tokenize samples from *ds* and concatenate their ids until >= *needed*.

    Each sample's text is read from its "text" or "content" field, and samples
    where both are missing or empty are skipped. Documents are joined with no
    separator token. The result may be shorter than *needed* if *ds* runs out.
    """
    buffer = []
    samples_consumed = 0
    t0 = time.time()
    for sample in ds:
        text = sample.get("text") or sample.get("content") or ""
        if not text:
            continue
        buffer.extend(tokenizer.encode(text, add_special_tokens=False))
        samples_consumed += 1
        if len(buffer) >= needed:
            break
    print(f"[eval] collected {len(buffer)} tokens from {samples_consumed} samples in {time.time()-t0:.1f}s")
    return buffer


def main():
    """Report held-out next-token CE and perplexity for CKPT, using only the CPU.

    Batches are fixed-size windows of EVAL_BATCH_SIZE x EVAL_SEQ_LEN tokens cut
    from a token stream of the dataset, after skipping the first EVAL_SKIP
    samples. If the stream yields fewer tokens than EVAL_BATCHES would need,
    only the complete batches are evaluated. If not even one batch is
    available, the script exits with an error.
    """
    if EVAL_SEQ_LEN < 2:
        # With one token per row there are no next-token targets to score.
        raise SystemExit("[eval] ERROR: EVAL_SEQ_LEN must be >= 2")

    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    device = torch.device("cpu")
    print(f"[eval] device=cpu threads={torch.get_num_threads()}")

    print(f"[eval] loading tokenizer: {TOKENIZER}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    model, config = _load_model()

    print(f"[eval] streaming {EVAL_DATASET}/{EVAL_DATASET_CONFIG}, skip={EVAL_SKIP}")
    from datasets import load_dataset
    ds = load_dataset(
        EVAL_DATASET, EVAL_DATASET_CONFIG,
        split="train", streaming=True,
    )
    ds = ds.skip(EVAL_SKIP)

    tokens_per_batch = EVAL_BATCH_SIZE * EVAL_SEQ_LEN
    needed = EVAL_BATCHES * tokens_per_batch
    print(f"[eval] need {needed} tokens for {EVAL_BATCHES} batches of {EVAL_BATCH_SIZE}x{EVAL_SEQ_LEN}")
    buffer = _collect_tokens(ds, tokenizer, needed)

    # A short final chunk cannot be .view()-ed into (batch, seq), so only
    # evaluate complete batches, and average over the number actually run.
    n_batches = min(EVAL_BATCHES, len(buffer) // tokens_per_batch)
    if n_batches == 0:
        raise SystemExit(
            f"[eval] ERROR: only {len(buffer)} tokens collected; "
            f"need at least {tokens_per_batch} for one batch"
        )
    if n_batches < EVAL_BATCHES:
        print(f"[eval] WARNING: dataset exhausted, evaluating {n_batches}/{EVAL_BATCHES} batches")

    total_ce = 0.0
    total_tokens = 0
    total_forward_time = 0.0
    total_exit = torch.zeros(config.num_loops)
    print(f"[eval] running {n_batches} batches")
    with torch.no_grad():
        for b in range(n_batches):
            off = b * tokens_per_batch
            batch = torch.tensor(
                buffer[off: off + tokens_per_batch], dtype=torch.long, device=device
            ).view(EVAL_BATCH_SIZE, EVAL_SEQ_LEN)

            t0 = time.time()
            # Inference path: returns (weighted-combined logits, exit_pdf)
            logits, exit_pdf = model(batch)
            dt = time.time() - t0  # forward pass only; excludes the CE below
            total_forward_time += dt

            # Cast logits to fp32 for CE — bf16 softmax can underflow on the rare-token tail.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = batch[:, 1:]
            n_tok = shift_labels.numel()
            # Reshape with the logits' own last dim rather than config.vocab_size,
            # so a padded output head cannot be silently mis-reshaped.
            batch_ce = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="sum",
            ).item()
            total_ce += batch_ce
            total_tokens += n_tok
            # NOTE(review): assumes exit_pdf is (batch, seq, num_loops) — averaging
            # over dims (0, 1) leaves one probability per loop. Confirm in BitLoopLM.
            total_exit += exit_pdf.float().mean(dim=(0, 1))

            avg = total_ce / total_tokens
            print(
                f" batch {b+1}/{n_batches}: ce={batch_ce/n_tok:.4f} "
                f"running_avg={avg:.4f} ppl={math.exp(avg):.2f} "
                f"forward={dt:.1f}s"
            )

    avg_ce = total_ce / total_tokens
    ppl = math.exp(avg_ce)
    exit_mean = (total_exit / n_batches).tolist()
    print("\n=== Eval summary ===")
    print(f" checkpoint : {CKPT}")
    print(f" held-out dataset: {EVAL_DATASET}/{EVAL_DATASET_CONFIG} skip={EVAL_SKIP}")
    print(f" tokens evaluated: {total_tokens}")
    print(f" avg CE (nats) : {avg_ce:.4f}")
    print(f" perplexity : {ppl:.2f}")
    print(f" exit pdf : {[f'L{i}:{v:.2f}' for i, v in enumerate(exit_mean)]}")
    print(f" forward time : {total_forward_time:.1f}s total ({total_forward_time/n_batches:.1f}s/batch)")
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|