"""Minimal CPU inference server for BitLMv47B (PyTorch path, no AVX-512). For testing on hardware that can't run the optimized C binary (e.g. 1-vCPU QEMU with no AVX). Loads a small .pt checkpoint, exposes a streaming generation endpoint, and supports a built-in benchmark. Run: python _inf_server_simple.py --ckpt synth_5m.pt --port 5002 python _inf_server_simple.py --ckpt synth_5m.pt --bench # bench only """ import argparse, time, os, sys, math, json import torch import torch.nn.functional as F from tokenizers import Tokenizer from flask import Flask, Response, request, render_template_string from model_v47b import BitLMv47B import model_v16 as v16 def load_model(ckpt_path, device='cpu'): ck = torch.load(ckpt_path, map_location=device, weights_only=False) args = ck['args'] kw = {} for k in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff', 'slope_groups'): if k in args: kw[k] = args[k] if 'seq_len' in args: kw['max_seq_len'] = args['seq_len'] elif 'max_seq_len' in args: kw['max_seq_len'] = args['max_seq_len'] m = BitLMv47B(**kw) sd = {k.replace('._orig_mod.', '.'): v for k, v in ck['model'].items()} miss, unexp = m.load_state_dict(sd, strict=False) print(f'load: missing={len(miss)} unexpected={len(unexp)} ' f'val_bpc={ck.get("val_bpc")}', flush=True) m.eval().to(device) return m, kw @torch.no_grad() def generate(m, tok, prompt, n_new=64, temp=0.0, top_k=0, device='cpu'): v16.set_gumbel_tau(0.1) enc = tok.encode(prompt) ids = enc.ids bos = tok.token_to_id('[BOS]') if bos is not None and (not ids or ids[0] != bos): ids = [bos] + ids x = torch.tensor([ids], dtype=torch.long, device=device) out_ids = list(ids) max_seq = m.max_seq_len if hasattr(m, 'max_seq_len') else 2048 for _ in range(n_new): if x.size(1) >= max_seq: break logits, _ = m(x, None) last = logits[0, -1] if temp <= 0: nxt = int(last.argmax().item()) else: probs = F.softmax(last / temp, dim=-1) if top_k > 0: v, idx = probs.topk(top_k) nxt = int(idx[torch.multinomial(v, 1).item()].item()) else: nxt = int(torch.multinomial(probs, 1).item()) out_ids.append(nxt) x = torch.cat([x, torch.tensor([[nxt]], device=device)], dim=1) return out_ids def benchmark(m, tok, n_new=64, prompt='The quick brown fox', warmup=2, runs=3): print(f'Benchmark: prompt={prompt!r} n_new={n_new}', flush=True) enc = tok.encode(prompt) ids = enc.ids print(f'prompt tokens: {len(ids)}', flush=True) # Warmup for _ in range(warmup): _ = generate(m, tok, prompt, n_new=8) # Time it times = [] for r in range(runs): t0 = time.time() out = generate(m, tok, prompt, n_new=n_new) elapsed = time.time() - t0 n_gen = len(out) - len(ids) rate = n_gen / elapsed times.append(rate) print(f' run {r+1}: {n_gen} tok in {elapsed:.2f}s = {rate:.2f} tok/s', flush=True) avg = sum(times) / len(times) print(f'\nAVG: {avg:.2f} tok/s ({n_new}-token generation, ' f'{m.d_model}-d {m.n_layers}-layer model)', flush=True) return avg HTML = """