"""
Interactive completion playground for BitLoopLM (CPU, bf16, FAST mode).

Type a prompt and the model continues it. Empty line or Ctrl-C to exit.

This is NOT a chat model — it's a base LM trained on cosmopedia. Best results
come from prompts that look like the start of an article or explanation:

    "The capital of France is"
    "Once upon a time, there was a"
    "Photosynthesis is the process by which"
    "import torch"

Env vars:
    CKPT            checkpoint path (default: ./bitlooplm-checkpoints/pytorch_model.bin
                    or .../resume.pt — both work)
    MODEL_SIZE      small | tiny (default: small)
    NUM_LOOPS       recurrent loops (default: 4)
    TOKENIZER       tokenizer id (default: SmolLM2-135M)
    MAX_NEW_TOKENS  default 80
    TEMPERATURE     default 0.8 (0 = greedy)
    TOP_P           default 0.9
    TOP_K           default 0 (0 = disabled)
    USE_CACHE       "1" to decode with the KV cache, anything else to re-run
                    the full sequence every step (default: 1)
    SEED            RNG seed re-applied before every variant (default: 42)
"""
import os
import sys
import time

import torch

# Make the sibling training/eval modules importable when this script is run
# from another working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, KVCache
from eval_cpu import freeze_bitlinears

CKPT = os.environ.get("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = os.environ.get("MODEL_SIZE", "small")
NUM_LOOPS = int(os.environ.get("NUM_LOOPS", "4"))
TOKENIZER = os.environ.get("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "80"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.8"))
TOP_P = float(os.environ.get("TOP_P", "0.9"))
TOP_K = int(os.environ.get("TOP_K", "0"))
USE_CACHE = os.environ.get("USE_CACHE", "1") == "1"
SEED = int(os.environ.get("SEED", "42"))


def load_model_for_inference():
    """Build the model and tokenizer selected by the module-level env settings.

    Loads the checkpoint at ``CKPT`` onto the CPU, freezes the BitLinear
    layers via ``freeze_bitlinears`` and casts the model to bfloat16.

    Returns:
        A ``(model, tokenizer, config)`` tuple.

    Raises:
        KeyError: if ``MODEL_SIZE`` is not a key of ``MODEL_CONFIGS``.
    """
    # Use every core but one, and never fewer than one thread.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    torch.set_float32_matmul_precision("medium")

    # Imported lazily so that importing this module does not require transformers.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[chat] loading {CKPT} ...")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only point CKPT at checkpoint files you trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    # A dict carrying a "model" entry is unwrapped to that entry; anything
    # else is treated as the state dict itself.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]

    # strict=False tolerates key mismatches, but a silent mismatch leaves
    # randomly initialised weights in place, so surface it to the user.
    # getattr() keeps this working if load_state_dict does not return the
    # usual (missing_keys, unexpected_keys) result object.
    result = model.load_state_dict(state, strict=False)
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing or unexpected:
        print(f"[chat] warning: checkpoint mismatch: {len(missing)} missing, "
              f"{len(unexpected)} unexpected keys")
        for kind, keys in (("missing", missing), ("unexpected", unexpected)):
            if keys:
                preview = ", ".join(keys[:5])
                suffix = ", ..." if len(keys) > 5 else ""
                print(f"[chat]   {kind}: {preview}{suffix}")

    n_frozen = freeze_bitlinears(model)
    model = model.to(torch.bfloat16)
    print(f"[chat] frozen {n_frozen} BitLinears, model in bf16, ready")
    return model, tokenizer, config


def sample_next(logits, temperature, top_p, top_k):
    """Apply temperature, top-k, top-p; return one sampled token id (1-D tensor of len 1).

    Args:
        logits: unnormalised scores over the vocabulary (last dim = vocab).
        temperature: values <= 0 select greedy argmax decoding.
        top_p: nucleus threshold; only applied when strictly between 0 and 1.
        top_k: keep the k highest-scoring tokens; values <= 0 disable it.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits.float() / temperature

    if top_k > 0:
        topk_vals, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
        cutoff = topk_vals[..., -1, None]
        logits = logits.masked_fill(logits < cutoff, float("-inf"))

    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Tokens beyond the nucleus are removed; shift so the first token above the
        # threshold is kept (otherwise the token that pushes cum_probs past top_p
        # is dropped, which over-truncates).
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Undo the sort: write the filtered values back to their vocabulary positions.
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens, temperature, top_p, top_k,
             use_cache=True, force_exit_loop=None):
    """Stream a sampled continuation of ``prompt`` to stdout.

    The prompt is echoed first, then each new token is printed as soon as it
    is sampled. Generation stops at the tokenizer's EOS id or after
    ``max_new_tokens`` tokens, and a throughput line is printed at the end
    (the timing includes the prompt prefill).

    Args:
        model: the language model; called as ``model(ids, ...)`` and expected
            to return ``(logits, _)``.
        tokenizer: tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature, top_p, top_k: sampling controls, see ``sample_next``.
        use_cache: decode one token at a time through a ``KVCache`` when true;
            otherwise re-run the whole sequence on every step.
        force_exit_loop: forwarded unchanged to the model on every call.

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
    if ids.numel() == 0:
        # Without at least one position there is no "last token" to read
        # logits from; fail with a clear message instead of an indexing error.
        raise ValueError("prompt produced no tokens; nothing to continue from")

    print(prompt, end="", flush=True)
    eos = tokenizer.eos_token_id

    # NOTE(review): tokens are decoded one at a time; with a byte-level
    # vocabulary a character split across tokens may print as U+FFFD.
    # Confirm whether that matters for the prompts used here.
    t_start = time.time()
    n_generated = 0

    if use_cache:
        cache = KVCache(num_slots=model.config.num_loops * model.config.num_hidden_layers)
        # Prefill: run the prompt through once, populate cache.
        logits, _ = model(ids, kv_cache=cache, force_exit_loop=force_exit_loop)
        next_logit = logits[0, -1, :]
        for step in range(max_new_tokens):
            next_id = sample_next(next_logit, temperature, top_p, top_k)
            token_int = next_id.item()
            if token_int == eos:
                break
            n_generated += 1
            print(tokenizer.decode([token_int], skip_special_tokens=True), end="", flush=True)
            if step + 1 == max_new_tokens:
                # Token budget spent: another forward pass would only produce
                # logits that are never sampled from.
                break
            # Decode step: feed only the new token, cache holds the rest.
            new_ids = next_id.view(1, 1)
            logits, _ = model(new_ids, kv_cache=cache, force_exit_loop=force_exit_loop)
            next_logit = logits[0, -1, :]
    else:
        for _ in range(max_new_tokens):
            logits, _ = model(ids, force_exit_loop=force_exit_loop)
            next_id = sample_next(logits[0, -1, :], temperature, top_p, top_k)
            token_int = next_id.item()
            if token_int == eos:
                break
            ids = torch.cat([ids, next_id.view(1, 1)], dim=-1)
            n_generated += 1
            print(tokenizer.decode([token_int], skip_special_tokens=True), end="", flush=True)

    # Clamp the elapsed time so the rate never divides by zero.
    elapsed = max(time.time() - t_start, 1e-6)
    rate = n_generated / elapsed
    print(f"\n\n[gen: {n_generated} tokens in {elapsed:.1f}s = {rate:.1f} tok/s]")


def main():
    """Run the interactive loop: read a prompt, generate once per exit variant."""
    model, tokenizer, config = load_model_for_inference()
    last_loop = config.num_loops - 1
    # Probe shallowest loop and an L3-ish deep one (clamped to the actual depth).
    deep_idx = min(3, last_loop)
    variants = [
        ("exit-gate (weighted)", None),
        ("forced exit @ L0", 0),
    ]
    # With a single loop the clamped deep index collapses onto L0; skip it
    # rather than generating the identical variant twice.
    if deep_idx > 0:
        variants.append((f"forced exit @ L{deep_idx}", deep_idx))

    print()
    print("Interactive completion. Empty line or Ctrl-C to exit.")
    print(f" max_new={MAX_NEW_TOKENS} temp={TEMPERATURE} top_p={TOP_P} top_k={TOP_K} cache={USE_CACHE} seed={SEED}")
    print(f" vocab={config.vocab_size} loops={config.num_loops} hidden={config.hidden_size}")

    try:
        while True:
            try:
                prompt = input("\n> ")
            except EOFError:
                break
            if not prompt.strip():
                break
            for label, force_loop in variants:
                print(f"\n--- {label} ---")
                # Reseed before every variant so sampling noise is held fixed and
                # any divergence is attributable to the exit choice, not the RNG.
                torch.manual_seed(SEED)
                generate(model, tokenizer, prompt, MAX_NEW_TOKENS, TEMPERATURE, TOP_P, TOP_K,
                         use_cache=USE_CACHE, force_exit_loop=force_loop)
    except KeyboardInterrupt:
        # Ctrl-C is the documented way to leave the session; exit quietly.
        pass
    print("\nbye")


if __name__ == "__main__":
    main()