| """ |
| Interactive completion playground for BitLoopLM (CPU, bf16, FAST mode). |
| |
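Type a prompt and the model continues it, once per early-exit variant (the
learned exit gate, a forced exit after loop 0, and a forced exit after loop
min(3, NUM_LOOPS - 1)). An empty line, EOF, or Ctrl-C exits.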
| |
| This is NOT a chat model — it's a base LM trained on cosmopedia. Best results |
| come from prompts that look like the start of an article or explanation: |
| "The capital of France is" |
| "Once upon a time, there was a" |
| "Photosynthesis is the process by which" |
| "import torch" |
| |
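Env vars:
  CKPT            checkpoint path (default: ./bitlooplm-checkpoints/pytorch_model.bin);
                  a resume.pt that nests the weights under a "model" key also works
  MODEL_SIZE      small | tiny                          (default: small)
  NUM_LOOPS       recurrent loops                       (default: 4)
  TOKENIZER       tokenizer id                          (default: HuggingFaceTB/SmolLM2-135M)
  MAX_NEW_TOKENS  max tokens generated per completion   (default: 80)
  TEMPERATURE     sampling temperature                  (default: 0.8; 0 = greedy)
  TOP_P           nucleus threshold                     (default: 0.9; values outside (0, 1) disable it)
  TOP_K           top-k cutoff                          (default: 0; 0 = disabled)
  USE_CACHE       "1" = use the KV cache, anything else = re-run the full sequence
                  every step                            (default: 1)
  SEED            RNG seed, re-applied before every completion (default: 42)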
| """ |
| import os |
| import sys |
| import time |
|
|
| import torch |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, KVCache |
| from eval_cpu import freeze_bitlinears |
|
|
def _env(name, default, cast=str):
    """Return environment variable *name* (or *default* when unset), converted with *cast*."""
    return cast(os.environ.get(name, default))


# Runtime settings; each can be overridden by the environment variable of the same name.
CKPT = _env("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = _env("MODEL_SIZE", "small")
NUM_LOOPS = _env("NUM_LOOPS", "4", int)
TOKENIZER = _env("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")
MAX_NEW_TOKENS = _env("MAX_NEW_TOKENS", "80", int)
TEMPERATURE = _env("TEMPERATURE", "0.8", float)
TOP_P = _env("TOP_P", "0.9", float)
TOP_K = _env("TOP_K", "0", int)
# Only the exact string "1" enables the KV cache; any other value disables it.
USE_CACHE = _env("USE_CACHE", "1") == "1"
SEED = _env("SEED", "42", int)
|
|
|
|
def load_model_for_inference():
    """Build BitLoopLM, load the checkpoint, and prepare it for CPU bf16 inference.

    Reads the module-level settings MODEL_SIZE, NUM_LOOPS, TOKENIZER and CKPT.

    Returns:
        (model, tokenizer, config): the model in eval mode with BitLinears
        frozen by ``freeze_bitlinears`` and cast to bfloat16; the Hugging Face
        tokenizer; and the ``BitLoopLMConfig`` the model was built from.

    Prints a warning listing any state-dict keys that did not match between
    the checkpoint and the model (loading itself is non-strict).
    """
    # Use all but one CPU core, but never fewer than one thread.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    torch.set_float32_matmul_precision("medium")

    # Local import: transformers is only used here, for the tokenizer.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    # Copy the preset so overriding num_loops does not mutate MODEL_CONFIGS.
    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[chat] loading {CKPT} ...")
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code on load -- only point CKPT at checkpoints you trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    # Checkpoints such as resume.pt that nest the weights under a "model" key are
    # unwrapped; otherwise the loaded object is used as the state dict directly.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]

    # strict=False tolerates key mismatches, but silently ignoring them would let a
    # mismatched checkpoint (e.g. one with prefixed keys) leave weights at their
    # random init. Surface the mismatch so garbage output is not mistaken for the model.
    result = model.load_state_dict(state, strict=False)
    for kind, keys in (("missing", result.missing_keys), ("unexpected", result.unexpected_keys)):
        if keys:
            preview = ", ".join(keys[:5])
            more = f" (+{len(keys) - 5} more)" if len(keys) > 5 else ""
            print(f"[chat] WARNING: {len(keys)} {kind} state-dict keys: {preview}{more}")

    n_frozen = freeze_bitlinears(model)
    model = model.to(torch.bfloat16)
    print(f"[chat] frozen {n_frozen} BitLinears, model in bf16, ready")
    return model, tokenizer, config
|
|
|
|
def sample_next(logits, temperature, top_p, top_k):
    """Pick the next token id from ``logits`` (last dimension = candidates).

    temperature <= 0 selects greedily via argmax. Otherwise the logits are
    divided by ``temperature`` and filtered, then one id is drawn with
    ``torch.multinomial``:
      * top_k > 0 keeps only candidates scoring at least the k-th best;
      * 0 < top_p < 1 keeps the smallest highest-probability prefix whose
        cumulative probability reaches top_p (the top candidate always survives).

    For 1-D logits the result is a tensor of shape (1,); leading dimensions are
    preserved otherwise.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    neg_inf = float("-inf")
    scaled = logits.float() / temperature

    if top_k > 0:
        k = min(top_k, scaled.size(-1))
        kth_best = torch.topk(scaled, k=k).values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_best, neg_inf)

    if 0.0 < top_p < 1.0:
        desc_vals, desc_order = torch.sort(scaled, descending=True)
        mass = torch.softmax(desc_vals, dim=-1).cumsum(dim=-1)
        # Drop a candidate only if the mass *before* it already exceeds top_p,
        # so the candidate that crosses the threshold (and the top one) is kept.
        drop = torch.zeros_like(mass, dtype=torch.bool)
        drop[..., 1:] = mass[..., :-1] > top_p
        desc_vals = desc_vals.masked_fill(drop, neg_inf)
        # Undo the sort: put every filtered value back at its original index.
        scaled = torch.full_like(scaled, neg_inf).scatter(-1, desc_order, desc_vals)

    return torch.multinomial(scaled.softmax(dim=-1), num_samples=1)
|
|
|
|
def _print_new_text(tokenizer, token_ids, n_printed, final=False):
    """Print the part of ``decode(token_ids)`` not printed yet; return the new printed length.

    The whole generated sequence is decoded each time instead of one token at a
    time: with byte-level BPE tokenizers a single token may carry only part of a
    multi-byte character, which decodes on its own to U+FFFD. While the decoded
    text ends in U+FFFD the tail is held back, unless ``final`` is set.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    if text.endswith("\ufffd") and not final:
        return n_printed
    print(text[n_printed:], end="", flush=True)
    # Never move the cursor backwards, so already-printed text is not repeated.
    return max(n_printed, len(text))


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens, temperature, top_p, top_k,
             use_cache=True, force_exit_loop=None):
    """Stream a completion of ``prompt`` to stdout, then print throughput stats.

    Args:
        model: BitLoopLM; called as ``model(ids, ...)`` returning ``(logits, _)``,
            where ``logits[0, -1, :]`` holds the next-token scores.
        tokenizer: Hugging Face tokenizer used to encode the prompt and decode output.
        prompt: text to continue; echoed verbatim before the completion.
        max_new_tokens: upper bound on sampled tokens; generation also stops at EOS.
        temperature, top_p, top_k: sampling controls forwarded to ``sample_next``.
        use_cache: if True, use a ``KVCache`` and feed only the newest token after
            the prompt prefill; otherwise re-run the model on the whole sequence
            every step.
        force_exit_loop: forwarded to the model unchanged.
    """
    ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
    print(prompt, end="", flush=True)

    eos = tokenizer.eos_token_id
    generated = []  # sampled token ids, prompt excluded
    n_printed = 0   # characters of decode(generated) already printed
    t_start = time.time()

    cache = None
    if use_cache:
        cache = KVCache(num_slots=model.config.num_loops * model.config.num_hidden_layers)

    # The first step feeds the whole prompt; later steps feed what the loop tail sets up.
    model_input = ids
    for _ in range(max_new_tokens):
        if cache is None:
            logits, _ = model(model_input, force_exit_loop=force_exit_loop)
        else:
            logits, _ = model(model_input, kv_cache=cache, force_exit_loop=force_exit_loop)
        next_id = sample_next(logits[0, -1, :], temperature, top_p, top_k)
        token_int = next_id.item()
        if token_int == eos:
            break
        generated.append(token_int)
        n_printed = _print_new_text(tokenizer, generated, n_printed)

        new_token = next_id.view(1, 1)
        # With the cache only the new token is fed next; without it the extended
        # sequence is re-run. The forward pass runs at the top of the next
        # iteration, so no pass is wasted after the final token.
        if cache is None:
            model_input = torch.cat([model_input, new_token], dim=-1)
        else:
            model_input = new_token

    # Flush any tail that was held back because it ended mid-character.
    _print_new_text(tokenizer, generated, n_printed, final=True)

    n_generated = len(generated)
    elapsed = max(time.time() - t_start, 1e-6)
    rate = n_generated / elapsed
    print(f"\n\n[gen: {n_generated} tokens in {elapsed:.1f}s = {rate:.1f} tok/s]")
|
|
|
|
def _exit_variants(num_loops):
    """Return the (label, force_exit_loop) pairs to compare for each prompt.

    Always includes the exit gate (None) and a forced exit at loop 0. A forced
    exit at loop min(3, num_loops - 1) is added only when that index differs
    from 0, so a single-loop model does not run the same variant twice.
    """
    variants = [
        ("exit-gate (weighted)", None),
        ("forced exit @ L0", 0),
    ]
    deep_idx = min(3, num_loops - 1)
    if deep_idx != 0:
        variants.append((f"forced exit @ L{deep_idx}", deep_idx))
    return variants


def main():
    """Run the interactive completion loop.

    Loads the model once, then for each prompt prints one completion per exit
    variant, re-seeding the RNG with SEED before each so the variants are
    sampled from the same RNG state. An empty line, EOF, or Ctrl-C ends the session.
    """
    model, tokenizer, config = load_model_for_inference()
    variants = _exit_variants(config.num_loops)
    print()
    print("Interactive completion. Empty line or Ctrl-C to exit.")
    print(f" max_new={MAX_NEW_TOKENS} temp={TEMPERATURE} top_p={TOP_P} top_k={TOP_K} cache={USE_CACHE} seed={SEED}")
    print(f" vocab={config.vocab_size} loops={config.num_loops} hidden={config.hidden_size}")

    try:
        while True:
            try:
                prompt = input("\n> ")
            except EOFError:
                break
            if not prompt.strip():
                break
            for label, force_loop in variants:
                print(f"\n--- {label} ---")
                # Same seed for every variant, so differences come from the exit choice.
                torch.manual_seed(SEED)
                generate(model, tokenizer, prompt, MAX_NEW_TOKENS, TEMPERATURE, TOP_P, TOP_K,
                         use_cache=USE_CACHE, force_exit_loop=force_loop)
    except KeyboardInterrupt:
        pass
    print("\nbye")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|