bitlooplm-small / chat_cpu.py
wmertens's picture
learnings
34f2e1c
"""
Interactive completion playground for BitLoopLM (CPU, bf16, FAST mode).
Type a prompt and the model continues it. Empty line or Ctrl-C to exit.
This is NOT a chat model — it's a base LM trained on cosmopedia. Best results
come from prompts that look like the start of an article or explanation:
"The capital of France is"
"Once upon a time, there was a"
"Photosynthesis is the process by which"
"import torch"
Env vars:
CKPT checkpoint path (default: ./bitlooplm-checkpoints/pytorch_model.bin
or .../resume.pt — both work)
MODEL_SIZE small | tiny (default: small)
NUM_LOOPS recurrent loops (default: 4)
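TOKENIZER tokenizer id (default: HuggingFaceTB/SmolLM2-135M)
MAX_NEW_TOKENS default 80
TEMPERATURE default 0.8 (0 = greedy)
TOP_P default 0.9 (values outside the open interval (0, 1) disable nucleus filtering)
TOP_K default 0 (0 = disabled)
USE_CACHE default 1 (1 = use the KV cache; any other value re-runs the full sequence each step)
SEED default 42 (RNG seed, re-applied before every generation)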
"""
import os
import sys
import time
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_bitlooplm_standalone import BitLoopLM, BitLoopLMConfig, MODEL_CONFIGS, KVCache
from eval_cpu import freeze_bitlinears
# Runtime configuration, read once from the environment at import time.

# Which weights to load and which architecture preset to build.
CKPT = os.getenv("CKPT", "./bitlooplm-checkpoints/pytorch_model.bin")
MODEL_SIZE = os.getenv("MODEL_SIZE", "small")
NUM_LOOPS = int(os.getenv("NUM_LOOPS", "4"))
TOKENIZER = os.getenv("TOKENIZER", "HuggingFaceTB/SmolLM2-135M")

# Sampling controls.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "80"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.8"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
TOP_K = int(os.getenv("TOP_K", "0"))

# Only the exact string "1" enables the KV cache; anything else disables it.
USE_CACHE = os.getenv("USE_CACHE", "1") == "1"
SEED = int(os.getenv("SEED", "42"))
def load_model_for_inference():
    """Build the model on CPU, load weights from ``CKPT`` and prepare it for inference.

    The architecture comes from ``MODEL_CONFIGS[MODEL_SIZE]`` with ``num_loops``
    overridden by ``NUM_LOOPS``. The checkpoint may be either a bare state dict
    or a dict that stores the state dict under the ``"model"`` key.

    Returns:
        A ``(model, tokenizer, config)`` tuple; the model is in eval mode and
        cast to bfloat16.

    Raises:
        KeyError: if ``MODEL_SIZE`` is not a key of ``MODEL_CONFIGS``.
    """
    # Use all but one logical core (never fewer than one); assume 4 when the
    # core count cannot be determined.
    torch.set_num_threads(max(1, (os.cpu_count() or 4) - 1))
    torch.set_float32_matmul_precision("medium")

    # Imported here so the dependency is only paid for when a model is loaded.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    cfg_dict = dict(MODEL_CONFIGS[MODEL_SIZE])  # copy: do not mutate the preset
    cfg_dict["num_loops"] = NUM_LOOPS
    config = BitLoopLMConfig(**cfg_dict)
    model = BitLoopLM(config)
    model.eval()

    print(f"[chat] loading {CKPT} ...")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code embedded in the file. Only point CKPT at
    # checkpoints you trust.
    state = torch.load(CKPT, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model" in state:
        state = state["model"]

    # strict=False tolerates key mismatches, but silently ignoring them can
    # leave parts of the model randomly initialised. Surface what was skipped.
    load_result = model.load_state_dict(state, strict=False)
    for kind, keys in (("missing", load_result.missing_keys),
                       ("unexpected", load_result.unexpected_keys)):
        if keys:
            preview = ", ".join(keys[:5])
            suffix = ", ..." if len(keys) > 5 else ""
            print(f"[chat] warning: {len(keys)} {kind} key(s) in checkpoint: {preview}{suffix}")

    # freeze_bitlinears returns a count that is reported below; the freezing
    # itself is implemented in eval_cpu.
    n_frozen = freeze_bitlinears(model)
    model = model.to(torch.bfloat16)
    print(f"[chat] frozen {n_frozen} BitLinears, model in bf16, ready")
    return model, tokenizer, config
def sample_next(logits, temperature, top_p, top_k):
    """Choose the next token id from a vector of logits.

    A non-positive ``temperature`` selects the argmax (greedy). Otherwise the
    logits are divided by ``temperature``, optionally restricted to the
    ``top_k`` highest entries (``top_k > 0``) and to the nucleus of cumulative
    probability ``top_p`` (only when ``0 < top_p < 1``), then sampled.

    Returns:
        A tensor holding one token id along the last dimension.
    """
    # Greedy path: no scaling or filtering is needed.
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)

    neg_inf = float("-inf")
    scaled = logits.float() / temperature

    if top_k > 0:
        k = min(top_k, scaled.size(-1))
        # Smallest value that still belongs to the top-k set.
        kth_best = torch.topk(scaled, k=k).values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_best, neg_inf)

    if 0.0 < top_p < 1.0:
        ranked, order = scaled.sort(descending=True)
        mass = ranked.softmax(dim=-1).cumsum(dim=-1)
        over = mass > top_p
        # Shift the mask one place to the right so the token that first
        # crosses the threshold is kept; the best token is always kept.
        drop = torch.cat([torch.zeros_like(over[..., :1]), over[..., :-1]], dim=-1)
        ranked = ranked.masked_fill(drop, neg_inf)
        # Undo the sort: put every filtered logit back at its vocabulary index.
        scaled = torch.full_like(scaled, neg_inf).scatter(-1, order, ranked)

    return torch.multinomial(scaled.softmax(dim=-1), num_samples=1)
def _stream_token(tokenizer, pending, token_int):
    """Print the text for a newly sampled token; return ids still awaiting output.

    Decoding a token on its own can end in U+FFFD when the tokenizer splits a
    multi-byte character across several tokens. Such ids are held back and
    decoded again together with the following token(s), so the complete
    character is printed instead of replacement glyphs. When nothing is held
    back this is exactly ``tokenizer.decode([token_int], ...)``.
    """
    pending = pending + [token_int]
    text = tokenizer.decode(pending, skip_special_tokens=True)
    if text.endswith("\ufffd"):
        return pending
    print(text, end="", flush=True)
    return []


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens, temperature, top_p, top_k,
             use_cache=True, force_exit_loop=None):
    """Stream a sampled continuation of ``prompt`` to stdout.

    The prompt is echoed first, then tokens are printed as they are sampled,
    followed by a ``[gen: ...]`` line with the token count and throughput.
    Generation stops after ``max_new_tokens`` tokens or when the tokenizer's
    EOS id is sampled (the EOS token is neither printed nor counted).

    Args:
        model: callable returning a tuple whose first item is logits indexed
            as ``[batch, position, vocab]``.
        tokenizer: provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text to continue; encoded without special tokens.
        max_new_tokens: upper bound on sampled tokens.
        temperature, top_p, top_k: forwarded to ``sample_next``.
        use_cache: when true, prefill once and then feed one token per step
            through a ``KVCache``; otherwise re-run the whole sequence each step.
        force_exit_loop: forwarded unchanged to the model.

    Returns:
        None. All output goes to stdout.
    """
    ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
    print(prompt, end="", flush=True)
    eos = tokenizer.eos_token_id
    # perf_counter is monotonic, so the reported rate cannot be skewed by
    # wall-clock adjustments.
    t_start = time.perf_counter()
    n_generated = 0
    pending = []  # sampled ids whose text is not printable yet

    if use_cache:
        cache = KVCache(num_slots=model.config.num_loops * model.config.num_hidden_layers)
        # Prefill: run the prompt through once, populate cache.
        logits, _ = model(ids, kv_cache=cache, force_exit_loop=force_exit_loop)
        next_logit = logits[0, -1, :]
        for step in range(max_new_tokens):
            next_id = sample_next(next_logit, temperature, top_p, top_k)
            token_int = next_id.item()
            if token_int == eos:
                break
            n_generated += 1
            pending = _stream_token(tokenizer, pending, token_int)
            if step + 1 == max_new_tokens:
                # Budget exhausted: another forward pass would be discarded.
                break
            # Decode step: feed only the new token, cache holds the rest.
            logits, _ = model(next_id.view(1, 1), kv_cache=cache,
                              force_exit_loop=force_exit_loop)
            next_logit = logits[0, -1, :]
    else:
        for _ in range(max_new_tokens):
            logits, _ = model(ids, force_exit_loop=force_exit_loop)
            next_id = sample_next(logits[0, -1, :], temperature, top_p, top_k)
            token_int = next_id.item()
            if token_int == eos:
                break
            ids = torch.cat([ids, next_id.view(1, 1)], dim=-1)
            n_generated += 1
            pending = _stream_token(tokenizer, pending, token_int)

    if pending:
        # Generation ended in the middle of a character; show what there is.
        print(tokenizer.decode(pending, skip_special_tokens=True), end="", flush=True)

    elapsed = max(time.perf_counter() - t_start, 1e-6)
    rate = n_generated / elapsed
    print(f"\n\n[gen: {n_generated} tokens in {elapsed:.1f}s = {rate:.1f} tok/s]")
def main():
    """Run the interactive prompt loop.

    Loads the model once, then reads prompts from stdin until an empty line,
    EOF or Ctrl-C. Every prompt is completed once per exit variant: the
    default path (``force_exit_loop=None``) plus forced exits at loop 0 and at
    a deeper loop.
    """
    model, tokenizer, config = load_model_for_inference()
    last_loop = config.num_loops - 1
    # Probe shallowest loop and an L3-ish deep one (clamped to the actual depth).
    deep_idx = min(3, last_loop)
    variants = [("exit-gate (weighted)", None)]
    # With a single loop the deep index collapses onto 0; dict.fromkeys drops
    # the duplicate (keeping order) so the same generation is not run twice.
    variants.extend(
        (f"forced exit @ L{idx}", idx) for idx in dict.fromkeys((0, deep_idx))
    )

    print()
    print("Interactive completion. Empty line or Ctrl-C to exit.")
    print(f" max_new={MAX_NEW_TOKENS} temp={TEMPERATURE} top_p={TOP_P} top_k={TOP_K} cache={USE_CACHE} seed={SEED}")
    print(f" vocab={config.vocab_size} loops={config.num_loops} hidden={config.hidden_size}")
    try:
        while True:
            try:
                prompt = input("\n> ")
            except EOFError:
                break
            if not prompt.strip():
                break
            for label, force_loop in variants:
                print(f"\n--- {label} ---")
                # Reseed before every variant so sampling noise is held fixed and
                # any divergence is attributable to the exit choice, not the RNG.
                torch.manual_seed(SEED)
                generate(model, tokenizer, prompt, MAX_NEW_TOKENS, TEMPERATURE, TOP_P, TOP_K,
                         use_cache=USE_CACHE, force_exit_loop=force_loop)
    except KeyboardInterrupt:
        # Ctrl-C is the documented way to quit, including mid-generation.
        pass
    print("\nbye")
# Start the interactive loop only when run as a script, not when imported.
if __name__ == "__main__":
main()