"""
CPU-only held-out evaluation for BitLoopLM.

Runs safely alongside a live GPU training job — only uses CPU + RAM.
Loads the weights currently saved to $CKPT (default:
./bitlooplm-checkpoints/pytorch_model.bin, written by save_and_push at
SAVE_INTERVAL milestones), streams a slice of the training dataset that is
intended to be held out, and reports pure CE + perplexity.

Config via env vars (all optional; defaults in parentheses):
    CKPT                 pytorch_model.bin or resume.pt path
                         (./bitlooplm-checkpoints/pytorch_model.bin)
    MODEL_SIZE           "small" | "tiny", must match the checkpoint ("small")
    NUM_LOOPS            number of recurrent loops, must match the checkpoint (4)
    EVAL_DATASET         HF dataset name (HuggingFaceTB/smollm-corpus)
    EVAL_DATASET_CONFIG  HF dataset config (cosmopedia-v2)
    EVAL_SKIP            samples discarded from the start of the streamed
                         "train" split before the eval window begins (50000).
                         NOTE: the window is only held-out if training has not
                         consumed that many samples of the same stream —
                         confirm for long training runs.
    EVAL_BATCHES         number of evaluation batches (8)
    EVAL_SEQ_LEN         tokens per sample; smaller = faster on CPU (256)
    EVAL_BATCH_SIZE      batch dim (4)
    TOKENIZER            HF tokenizer id (HuggingFaceTB/SmolLM2-135M)
    FAST                 "1" (default) = pre-freeze BitLinear weights + bf16 cast.
                         "0" disables for an apples-to-apples baseline comparison.

Usage:
    uv run python eval_cpu.py
    CKPT=./bitlooplm-checkpoints/resume.pt EVAL_BATCHES=16 uv run python eval_cpu.py
    FAST=0 uv run python eval_cpu.py        # baseline (slower, fp32, no freeze)
"""

import math
import os
import sys
import time

import torch
import torch.nn.functional as F

# Import model from the training script (same directory)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import (  # noqa: E402  (needs the sys.path tweak above)
    BitLoopLM,
    BitLoopLMConfig,
    MODEL_CONFIGS,
    BitLinear,
    weight_quant,
    activation_quant,
)

# --- Configuration (env-var overrides; see module docstring) ---------------
CKPT = os.environ.get("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = os.environ.get("MODEL_SIZE", "small")  # key into MODEL_CONFIGS
NUM_LOOPS = int(os.environ.get("NUM_LOOPS", "4"))
TOKENIZER = os.environ.get("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
EVAL_DATASET = os.environ.get("EVAL_DATASET", "HuggingFaceTB/smollm-corpus")
EVAL_DATASET_CONFIG = os.environ.get("EVAL_DATASET_CONFIG", "cosmopedia-v2")
EVAL_SKIP = int(os.environ.get("EVAL_SKIP", "50000"))
EVAL_BATCHES = int(os.environ.get("EVAL_BATCHES", "8"))
EVAL_SEQ_LEN = int(os.environ.get("EVAL_SEQ_LEN", "256"))
EVAL_BATCH_SIZE = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
FAST = os.environ.get("FAST", "1") == "1"


def _frozen_bitlinear_forward(self, x):
    """Eval-only BitLinear forward that reuses the pre-quantized weight buffer.

    Quantizes the activations, multiplies by ``self._quant_weight_cached``
    (registered by ``freeze_bitlinears``), and adds the bias if present.

    NOTE(review): this mirrors only activation_quant -> linear -> bias. If the
    training script's BitLinear.forward does anything more (e.g. an input
    norm), it must be replicated here — confirm against
    train_bitlooplm_standalone.
    """
    out = F.linear(activation_quant(x), self._quant_weight_cached)
    if self.bias is not None:
        out = out + self.bias
    return out


def freeze_bitlinears(model):
    """Pre-quantize BitLinear weights once and swap in a lean forward.

    The training script's BitLinear recomputes weight_quant() on every forward
    (correct for STE) and pads it with detach-add tricks (training-only). For
    eval the weights are static, so both costs buy nothing. Walk the model,
    cache the quantized weight as a buffer (so ``.to(dtype)`` casts it along
    with everything else), and replace ``forward`` with a lean version.

    Args:
        model: module tree to patch in place.

    Returns:
        The number of BitLinear layers patched.
    """
    import types

    n = 0
    for module in model.modules():
        if not isinstance(module, BitLinear):
            continue
        with torch.no_grad():
            qw = weight_quant(module.weight).detach().clone()
        # Non-persistent buffer: cast by model.to(dtype), never saved to a state_dict.
        module.register_buffer("_quant_weight_cached", qw, persistent=False)
        # Instance attribute shadows the class forward; nn.Module.__call__ uses it.
        module.forward = types.MethodType(_frozen_bitlinear_forward, module)
        n += 1
    return n


def _perplexity(avg_ce):
    """Return exp(avg_ce), or +inf instead of raising OverflowError for huge CE."""
    try:
        return math.exp(avg_ce)
    except OverflowError:
        return float("inf")


def _load_model():
    """Build BitLoopLM for MODEL_SIZE / NUM_LOOPS and load CKPT into it.

    Returns:
        (model, config): the model in eval mode on CPU, and its config.
    """
    print(f"[eval] building model: size={MODEL_SIZE} loops={NUM_LOOPS}")
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[eval] loading checkpoint: {CKPT}")
    # weights_only=False lets torch.load unpickle arbitrary objects, which can
    # execute code embedded in the file — only point CKPT at trusted files.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        # resume.pt wraps weights under "model"
        state = state["model"]
    # strict=False: report mismatches instead of aborting. Missing keys keep
    # their freshly-initialized values, so many of them make the eval meaningless.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[eval] WARNING missing keys: {len(missing)} (e.g., {missing[:3]})")
    if unexpected:
        print(f"[eval] WARNING unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})")

    n_params = sum(p.numel() for p in model.parameters())
    print(f"[eval] model loaded, {n_params/1e6:.1f}M params")
    return model, config


def _collect_eval_tokens(tokenizer, needed):
    """Stream the eval dataset past EVAL_SKIP samples and tokenize into one flat list.

    Documents are concatenated back-to-back with no separator tokens, stopping
    once at least ``needed`` ids are buffered. Returns fewer than ``needed``
    ids if the stream runs out first.
    """
    print(f"[eval] streaming {EVAL_DATASET}/{EVAL_DATASET_CONFIG}, skip={EVAL_SKIP}")
    from datasets import load_dataset

    ds = load_dataset(
        EVAL_DATASET,
        EVAL_DATASET_CONFIG,
        split="train",
        streaming=True,
    ).skip(EVAL_SKIP)

    buffer = []
    samples_consumed = 0
    t0 = time.time()
    for sample in ds:
        # Corpora differ in which field holds the document text.
        text = sample.get("text") or sample.get("content") or ""
        if not text:
            continue
        buffer.extend(tokenizer.encode(text, add_special_tokens=False))
        samples_consumed += 1
        if len(buffer) >= needed:
            break
    print(f"[eval] collected {len(buffer)} tokens from {samples_consumed} samples in {time.time()-t0:.1f}s")
    return buffer


def main():
    """Evaluate CKPT on a streamed token window and print CE / perplexity.

    Raises:
        SystemExit: on invalid EVAL_* settings, when the stream yields no
            complete batch, or when token ids exceed the model vocabulary.
    """
    if EVAL_BATCHES < 1 or EVAL_BATCH_SIZE < 1 or EVAL_SEQ_LEN < 2:
        # SEQ_LEN >= 2 so the next-token shift leaves at least one target.
        raise SystemExit(
            f"[eval] need EVAL_BATCHES>=1, EVAL_BATCH_SIZE>=1, EVAL_SEQ_LEN>=2 "
            f"(got {EVAL_BATCHES}, {EVAL_BATCH_SIZE}, {EVAL_SEQ_LEN})"
        )

    # Leave one core free for the co-located training job's host work.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    device = torch.device("cpu")
    print(f"[eval] device=cpu threads={torch.get_num_threads()}")

    print(f"[eval] loading tokenizer: {TOKENIZER}")
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    model, config = _load_model()

    if FAST:
        # Tier 1+2: bf16 GEMM + pre-frozen BitLinear weights.
        torch.set_float32_matmul_precision("medium")
        n_frozen = freeze_bitlinears(model)
        # Cast AFTER freezing so weight_quant runs on fp32 weights, then the
        # cached quantized buffer is cast to bf16 along with everything else.
        model = model.to(torch.bfloat16)
        print(f"[eval] FAST mode: pre-froze {n_frozen} BitLinears, cast model to bf16")

    tokens_per_batch = EVAL_BATCH_SIZE * EVAL_SEQ_LEN
    needed = EVAL_BATCHES * tokens_per_batch
    print(f"[eval] need {needed} tokens for {EVAL_BATCHES} batches of {EVAL_BATCH_SIZE}x{EVAL_SEQ_LEN}")
    buffer = _collect_eval_tokens(tokenizer, needed)

    # The stream may run dry; evaluate only complete batches instead of
    # crashing in .view() on a short final chunk.
    n_batches = min(EVAL_BATCHES, len(buffer) // tokens_per_batch)
    if n_batches == 0:
        raise SystemExit(
            f"[eval] only {len(buffer)} tokens collected; need {tokens_per_batch} for one batch"
        )
    if n_batches < EVAL_BATCHES:
        print(f"[eval] WARNING dataset exhausted: evaluating {n_batches}/{EVAL_BATCHES} batches")

    eval_tokens = buffer[: n_batches * tokens_per_batch]
    max_id = max(eval_tokens)
    if max_id >= config.vocab_size:
        raise SystemExit(
            f"[eval] token id {max_id} >= model vocab_size {config.vocab_size}; "
            f"TOKENIZER={TOKENIZER} does not match the model"
        )

    total_ce = 0.0
    total_tokens = 0
    total_forward_time = 0.0
    total_exit = torch.zeros(config.num_loops)

    print(f"[eval] running {n_batches} batches")
    with torch.no_grad():
        for b in range(n_batches):
            chunk = eval_tokens[b * tokens_per_batch:(b + 1) * tokens_per_batch]
            batch = torch.tensor(chunk, dtype=torch.long, device=device).view(
                EVAL_BATCH_SIZE, EVAL_SEQ_LEN
            )

            t0 = time.time()
            # Inference path: returns (weighted-combined logits, exit_pdf)
            logits, exit_pdf = model(batch)
            # Cast logits to fp32 for CE — bf16 softmax can underflow on the rare-token tail.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = batch[:, 1:]
            # Use the logits' real last dim rather than assuming config.vocab_size.
            batch_ce = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="sum",
            ).item()
            batch_tokens = shift_labels.numel()
            total_ce += batch_ce
            total_tokens += batch_tokens
            # assumes exit_pdf is (batch, seq, num_loops): averaged over batch
            # and positions into a per-loop exit distribution — TODO confirm.
            total_exit += exit_pdf.float().mean(dim=(0, 1))
            dt = time.time() - t0
            total_forward_time += dt

            # Per-loop CE is not computed here (it would double the forward
            # cost); use the training logs for per-loop stats.
            avg = total_ce / total_tokens
            print(
                f" batch {b+1}/{n_batches}: ce={batch_ce/batch_tokens:.4f} "
                f"running_avg={avg:.4f} ppl={_perplexity(avg):.2f} "
                f"forward={dt:.1f}s"
            )

    avg_ce = total_ce / total_tokens
    ppl = _perplexity(avg_ce)
    exit_mean = (total_exit / n_batches).tolist()

    print("\n=== Eval summary ===")
    print(f" checkpoint : {CKPT}")
    print(f" held-out dataset: {EVAL_DATASET}/{EVAL_DATASET_CONFIG} skip={EVAL_SKIP}")
    print(f" tokens evaluated: {total_tokens}")
    print(f" avg CE (nats) : {avg_ce:.4f}")
    print(f" perplexity : {ppl:.2f}")
    print(f" exit pdf : {[f'L{i}:{v:.2f}' for i, v in enumerate(exit_mean)]}")
    print(f" forward time : {total_forward_time:.1f}s total ({total_forward_time/n_batches:.1f}s/batch)")


if __name__ == "__main__":
    main()