| """ |
| CPU-only held-out evaluation for BitLoopLM. |
| |
| Runs safely alongside a live GPU training job — only uses CPU + RAM. |
| Loads the weights currently saved to $CKPT (default: pytorch_model.bin |
| written by save_and_push at SAVE_INTERVAL milestones), streams a |
| held-out slice of the training dataset, and reports pure CE + perplexity. |
| |
| Config via env vars (all optional): |
| CKPT pytorch_model.bin / resume.pt path (default: ./bitlooplm-checkpoints/pytorch_model.bin) |
| MODEL_SIZE "small" | "tiny" (must match the checkpoint) |
| NUM_LOOPS number of recurrent loops (must match) |
| EVAL_DATASET HF dataset name |
| EVAL_DATASET_CONFIG HF dataset config |
| EVAL_SKIP samples to discard before taking the eval window (ensures held-out) |
| EVAL_BATCHES number of evaluation batches |
| EVAL_SEQ_LEN tokens per sample (smaller = faster on CPU) |
| EVAL_BATCH_SIZE batch dim |
| TOKENIZER HF tokenizer id |
| FAST "1" (default) = pre-freeze BitLinear weights + bf16 cast. |
| "0" disables for an apples-to-apples baseline comparison. |
| |
| Usage: |
| uv run python eval_cpu.py |
| CKPT=./bitlooplm-checkpoints/resume.pt EVAL_BATCHES=16 uv run python eval_cpu.py |
| FAST=0 uv run python eval_cpu.py # baseline (slower, fp32, no freeze) |
| """ |
| import math |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from train_bitlooplm_standalone import ( |
| BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, |
| BitLinear, weight_quant, activation_quant, |
| ) |
|
|
# ---------------------------------------------------------------------------
# Runtime configuration. Every value can be overridden through an environment
# variable of the same name; the literals below are the fallbacks.
# ---------------------------------------------------------------------------


def _env_int(name, default):
    """Return environment variable *name* parsed as int (*default* when unset)."""
    return int(os.environ.get(name, default))


# Checkpoint location and the architecture settings used to rebuild the model.
CKPT = os.environ.get("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = os.environ.get("MODEL_SIZE", "small")
NUM_LOOPS = _env_int("NUM_LOOPS", "4")

# Tokenizer and the streamed dataset the evaluation window is taken from.
TOKENIZER = os.environ.get("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
EVAL_DATASET = os.environ.get("EVAL_DATASET", "HuggingFaceTB/smollm-corpus")
EVAL_DATASET_CONFIG = os.environ.get("EVAL_DATASET_CONFIG", "cosmopedia-v2")

# Size of the evaluation window: samples skipped first, then batches x rows x tokens.
EVAL_SKIP = _env_int("EVAL_SKIP", "50000")
EVAL_BATCHES = _env_int("EVAL_BATCHES", "8")
EVAL_SEQ_LEN = _env_int("EVAL_SEQ_LEN", "256")
EVAL_BATCH_SIZE = _env_int("EVAL_BATCH_SIZE", "4")

# Only the exact string "1" enables the fast path; any other value disables it.
FAST = os.environ.get("FAST", "1") == "1"
|
|
|
|
def freeze_bitlinears(model):
    """Cache each BitLinear's quantized weight and install a lean eval forward.

    During evaluation the weights never change, so quantizing them again on
    every call is wasted work. For every BitLinear found in ``model`` this
    function:

    * runs ``weight_quant`` on the layer's weight once and stores the result
      as the non-persistent buffer ``_quant_weight_cached`` (a buffer, so a
      later ``model.to(dtype)`` casts it together with the other tensors, and
      non-persistent, so it is left out of ``state_dict()``);
    * rebinds the instance's ``forward`` to a version that quantizes only the
      activations and multiplies by the cached weight.

    Args:
        model: module tree to walk via ``model.modules()``; patched in place.

    Returns:
        The number of BitLinear layers that were patched.
    """
    from types import MethodType

    def _frozen_forward(self, x):
        # Activations are still quantized per call; only the weight is cached.
        y = F.linear(activation_quant(x), self._quant_weight_cached)
        if self.bias is None:
            return y
        return y + self.bias

    patched = 0
    for layer in model.modules():
        if not isinstance(layer, BitLinear):
            continue
        with torch.no_grad():
            frozen_weight = weight_quant(layer.weight).detach().clone()
        layer.register_buffer("_quant_weight_cached", frozen_weight, persistent=False)
        # Bind per instance so only this layer object gets the replacement forward.
        layer.forward = MethodType(_frozen_forward, layer)
        patched += 1
    return patched
|
|
|
|
def _load_model():
    """Build the model described by MODEL_SIZE / NUM_LOOPS and load CKPT into it.

    Accepts either a bare state dict or a dict that nests the weights under
    the ``"model"`` key. Keys that do not line up are reported as warnings
    rather than errors (``strict=False``). When FAST is set, BitLinear weights
    are pre-quantized via ``freeze_bitlinears`` and the model is cast to bf16.

    Returns:
        ``(model, config)`` with the model in eval mode.
    """
    print(f"[eval] building model: size={MODEL_SIZE} loops={NUM_LOOPS}")
    # Copy the preset so overriding num_loops does not mutate MODEL_CONFIGS.
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[eval] loading checkpoint: {CKPT}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point CKPT at checkpoints you produced or otherwise trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        # Checkpoint wraps the weights under "model"; unwrap to the state dict.
        state = state["model"]
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[eval] WARNING missing keys: {len(missing)} (e.g., {missing[:3]})")
    if unexpected:
        print(f"[eval] WARNING unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})")

    n_params = sum(p.numel() for p in model.parameters())
    print(f"[eval] model loaded, {n_params/1e6:.1f}M params")

    if FAST:
        torch.set_float32_matmul_precision("medium")
        # Freeze before the cast so the cached quantized weights are computed
        # from the loaded weights and then converted along with the model.
        n_frozen = freeze_bitlinears(model)
        model = model.to(torch.bfloat16)
        print(f"[eval] FAST mode: pre-froze {n_frozen} BitLinears, cast model to bf16")

    return model, config


def _collect_tokens(tokenizer, needed):
    """Stream the eval dataset and return a flat list of at least ``needed`` token ids.

    The first EVAL_SKIP samples of the stream are discarded. Samples without a
    non-empty ``"text"`` or ``"content"`` field are ignored. The returned list
    may be shorter than ``needed`` if the stream ends early; the caller is
    responsible for handling that.
    """
    print(f"[eval] streaming {EVAL_DATASET}/{EVAL_DATASET_CONFIG}, skip={EVAL_SKIP}")
    from datasets import load_dataset
    ds = load_dataset(
        EVAL_DATASET, EVAL_DATASET_CONFIG,
        split="train", streaming=True,
    )
    ds = ds.skip(EVAL_SKIP)

    print(f"[eval] need {needed} tokens for {EVAL_BATCHES} batches of {EVAL_BATCH_SIZE}x{EVAL_SEQ_LEN}")

    buffer = []
    samples_consumed = 0
    t0 = time.perf_counter()
    for sample in ds:
        text = sample.get("text") or sample.get("content") or ""
        if not text:
            continue
        buffer.extend(tokenizer.encode(text, add_special_tokens=False))
        samples_consumed += 1
        if len(buffer) >= needed:
            break
    print(f"[eval] collected {len(buffer)} tokens from {samples_consumed} samples in {time.perf_counter()-t0:.1f}s")
    return buffer


def main():
    """Evaluate the checkpoint on CPU and print cross-entropy and perplexity.

    Loads the tokenizer and model, streams tokens from the evaluation dataset,
    packs them into ``EVAL_BATCHES`` batches of ``EVAL_BATCH_SIZE x EVAL_SEQ_LEN``
    ids, and accumulates the summed next-token cross-entropy over all batches.
    Per-batch progress and a final summary are printed to stdout.

    If the stream ends before enough tokens were collected, only the full
    batches that are available are evaluated and a warning is printed.

    Raises:
        ValueError: if EVAL_BATCHES or EVAL_BATCH_SIZE is below 1, or
            EVAL_SEQ_LEN is below 2 (a length-1 sequence has no next-token
            target).
        RuntimeError: if the stream yields too few tokens for even one batch.
    """
    if EVAL_BATCHES < 1 or EVAL_BATCH_SIZE < 1 or EVAL_SEQ_LEN < 2:
        raise ValueError(
            "EVAL_BATCHES and EVAL_BATCH_SIZE must be >= 1 and EVAL_SEQ_LEN >= 2 "
            f"(got {EVAL_BATCHES}, {EVAL_BATCH_SIZE}, {EVAL_SEQ_LEN})"
        )

    # Leave one core free for whatever else is running on the machine.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    device = torch.device("cpu")

    print(f"[eval] device=cpu threads={torch.get_num_threads()}")
    print(f"[eval] loading tokenizer: {TOKENIZER}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    model, config = _load_model()

    tokens_per_batch = EVAL_BATCH_SIZE * EVAL_SEQ_LEN
    buffer = _collect_tokens(tokenizer, EVAL_BATCHES * tokens_per_batch)

    # Only full batches can be reshaped to (EVAL_BATCH_SIZE, EVAL_SEQ_LEN).
    n_batches = min(EVAL_BATCHES, len(buffer) // tokens_per_batch)
    if n_batches == 0:
        raise RuntimeError(
            f"dataset stream yielded only {len(buffer)} tokens after skip={EVAL_SKIP}; "
            f"need at least {tokens_per_batch} for one batch"
        )
    if n_batches < EVAL_BATCHES:
        print(
            f"[eval] WARNING stream ended early: only {len(buffer)} tokens, "
            f"evaluating {n_batches}/{EVAL_BATCHES} batches"
        )

    total_ce = 0.0
    total_tokens = 0
    total_forward_time = 0.0
    total_exit = torch.zeros(config.num_loops)

    print(f"[eval] running {n_batches} batches")
    with torch.no_grad():
        for b in range(n_batches):
            off = b * tokens_per_batch
            chunk = buffer[off: off + tokens_per_batch]
            batch = torch.tensor(chunk, dtype=torch.long, device=device).view(
                EVAL_BATCH_SIZE, EVAL_SEQ_LEN
            )

            # The timed span covers the forward pass and the loss computation.
            t0 = time.perf_counter()
            logits, exit_pdf = model(batch)

            # Position t predicts token t+1; upcast so the loss is computed in
            # fp32 even when the model runs in bf16.
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:].contiguous()
            ce_tokens = F.cross_entropy(
                shift_logits.view(-1, config.vocab_size),
                shift_labels.view(-1),
                reduction="sum",
            )
            batch_ce = ce_tokens.item()
            total_ce += batch_ce
            total_tokens += shift_labels.numel()
            # Average over the first two dims; assumes the last dim of exit_pdf
            # has num_loops entries (TODO confirm against the model). Upcast
            # first so the mean is not taken in bf16.
            total_exit += exit_pdf.float().mean(dim=(0, 1))
            dt = time.perf_counter() - t0
            total_forward_time += dt

            avg = total_ce / total_tokens
            print(
                f" batch {b+1}/{n_batches}: ce={batch_ce/shift_labels.numel():.4f} "
                f"running_avg={avg:.4f} ppl={math.exp(avg):.2f} "
                f"forward={dt:.1f}s"
            )

    avg_ce = total_ce / total_tokens
    ppl = math.exp(avg_ce)
    exit_mean = (total_exit / n_batches).tolist()

    print("\n=== Eval summary ===")
    print(f" checkpoint : {CKPT}")
    print(f" held-out dataset: {EVAL_DATASET}/{EVAL_DATASET_CONFIG} skip={EVAL_SKIP}")
    print(f" tokens evaluated: {total_tokens}")
    print(f" avg CE (nats) : {avg_ce:.4f}")
    print(f" perplexity : {ppl:.2f}")
    print(f" exit pdf : {[f'L{i}:{v:.2f}' for i, v in enumerate(exit_mean)]}")
    print(f" forward time : {total_forward_time:.1f}s total ({total_forward_time/n_batches:.1f}s/batch)")
|
|
|
|
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 after a successful run.
    sys.exit(main())
|
|