| """Minimal CPU inference server for BitLMv47B (PyTorch path, no AVX-512). |
| |
| For testing on hardware that can't run the optimized C binary |
| (e.g. 1-vCPU QEMU with no AVX). Loads a small .pt checkpoint, exposes a |
| streaming generation endpoint, and supports a built-in benchmark. |
| |
| Run: |
| python _inf_server_simple.py --ckpt synth_5m.pt --port 5002 |
| python _inf_server_simple.py --ckpt synth_5m.pt --bench # bench only |
| """ |
| import argparse, time, os, sys, math, json |
| import torch |
| import torch.nn.functional as F |
| from tokenizers import Tokenizer |
| from flask import Flask, Response, request, render_template_string |
|
|
| from model_v47b import BitLMv47B |
| import model_v16 as v16 |
|
|
|
|
def load_model(ckpt_path, device='cpu'):
    """Load a BitLMv47B checkpoint and return the model in eval mode.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load`` that holds
            an ``'args'`` mapping and a ``'model'`` state dict.
        device: Used as ``map_location`` for loading and as the target of the
            final ``.to(device)``.

    Returns:
        A ``(model, kw)`` tuple, where ``kw`` is the dict of constructor
        keyword arguments recovered from the checkpoint's ``'args'``.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source.
    ck = torch.load(ckpt_path, map_location=device, weights_only=False)
    args = ck['args']

    # Forward only the hyper-parameters the checkpoint actually recorded.
    ctor_keys = ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff',
                 'slope_groups')
    kw = {k: args[k] for k in ctor_keys if k in args}
    # 'seq_len' takes precedence over 'max_seq_len' when both are present.
    if 'seq_len' in args:
        kw['max_seq_len'] = args['seq_len']
    elif 'max_seq_len' in args:
        kw['max_seq_len'] = args['max_seq_len']
    m = BitLMv47B(**kw)

    # torch.compile wrappers insert '_orig_mod' into parameter names. Strip it
    # both in the middle of a key (compiled sub-module) and at the very start
    # (whole model compiled); otherwise strict=False would silently skip
    # every weight and leave the model randomly initialised.
    prefix = '_orig_mod.'
    sd = {}
    for k, v in ck['model'].items():
        k = k.replace('._orig_mod.', '.')
        if k.startswith(prefix):
            k = k[len(prefix):]
        sd[k] = v

    miss, unexp = m.load_state_dict(sd, strict=False)
    print(f'load: missing={len(miss)} unexpected={len(unexp)} '
          f'val_bpc={ck.get("val_bpc")}', flush=True)
    m.eval().to(device)
    return m, kw
|
|
|
|
@torch.no_grad()
def generate(m, tok, prompt, n_new=64, temp=0.0, top_k=0, device='cpu'):
    """Extend *prompt* by up to *n_new* tokens and return all token ids.

    The whole sequence is re-run through the model for every new token.

    Args:
        m: Model called as ``m(x, None)`` that returns ``(logits, _)``.
        tok: Tokenizer providing ``encode(...).ids`` and ``token_to_id``.
        prompt: Text to condition on.
        n_new: Maximum number of tokens to append.
        temp: Sampling temperature; ``<= 0`` selects greedy argmax decoding.
        top_k: If ``> 0``, sample only among the ``top_k`` most likely tokens
            (clamped to the vocabulary size). Ignored for greedy decoding.
        device: Device on which the id tensors are created.

    Returns:
        A list of token ids: the prompt ids (with ``[BOS]`` prepended when
        the tokenizer defines it and the prompt does not already start with
        it) followed by the generated ids.
    """
    v16.set_gumbel_tau(0.1)
    ids = tok.encode(prompt).ids
    bos = tok.token_to_id('[BOS]')
    if bos is not None and (not ids or ids[0] != bos):
        ids = [bos] + ids
    out_ids = list(ids)
    if not ids:
        # Empty prompt and no BOS token: there is no position to read
        # logits from, so there is nothing to generate.
        return out_ids

    x = torch.tensor([ids], dtype=torch.long, device=device)
    max_seq = getattr(m, 'max_seq_len', 2048)
    for _ in range(n_new):
        if x.size(1) >= max_seq:
            break
        logits, _ = m(x, None)
        last = logits[0, -1]
        if temp <= 0:
            nxt = int(last.argmax().item())
        else:
            probs = F.softmax(last / temp, dim=-1)
            if top_k > 0:
                # topk raises if k exceeds the vocabulary size, so clamp it.
                k = min(top_k, probs.size(-1))
                v, idx = probs.topk(k)
                # multinomial accepts unnormalised non-negative weights.
                nxt = int(idx[torch.multinomial(v, 1).item()].item())
            else:
                nxt = int(torch.multinomial(probs, 1).item())
        out_ids.append(nxt)
        x = torch.cat([x, torch.tensor([[nxt]], device=device)], dim=1)
    return out_ids
|
|
|
|
def benchmark(m, tok, n_new=64, prompt='The quick brown fox', warmup=2, runs=3):
    """Time greedy generation and print/return the mean tokens per second.

    Args:
        m: Model passed to ``generate``; ``m.d_model`` and ``m.n_layers`` are
            read for the summary line.
        tok: Tokenizer passed to ``generate``.
        n_new: Number of tokens requested per timed run.
        prompt: Prompt text used for every run.
        warmup: Number of untimed 8-token generations run first.
        runs: Number of timed generations.

    Returns:
        The average tokens/second over the timed runs (``0.0`` if ``runs``
        is zero).
    """
    print(f'Benchmark: prompt={prompt!r} n_new={n_new}', flush=True)
    ids = tok.encode(prompt).ids
    print(f'prompt tokens: {len(ids)}', flush=True)

    # generate() prepends [BOS] when the tokenizer defines it and the prompt
    # does not already start with it. Count that token as part of the prompt
    # so it is not reported as a generated token.
    bos = tok.token_to_id('[BOS]')
    prompt_len = len(ids)
    if bos is not None and (not ids or ids[0] != bos):
        prompt_len += 1

    for _ in range(warmup):
        generate(m, tok, prompt, n_new=8)

    rates = []
    for r in range(runs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        out = generate(m, tok, prompt, n_new=n_new)
        elapsed = time.perf_counter() - t0
        n_gen = len(out) - prompt_len
        rate = n_gen / max(elapsed, 1e-9)
        rates.append(rate)
        print(f' run {r+1}: {n_gen} tok in {elapsed:.2f}s = {rate:.2f} tok/s',
              flush=True)
    avg = sum(rates) / len(rates) if rates else 0.0
    print(f'\nAVG: {avg:.2f} tok/s ({n_new}-token generation, '
          f'{m.d_model}-d {m.n_layers}-layer model)', flush=True)
    return avg
|
|
|
|
# Single-page UI returned by the '/' route: a prompt box that POSTs JSON to
# /gen and appends the streamed text response to a <pre>.
# The page is rendered with render_template_string, so it must not contain
# Jinja delimiters ('{{', '{%', '{#').
# One TextDecoder is reused with {stream:true} so a multi-byte UTF-8
# character split across two network chunks is decoded correctly.
HTML = """<!doctype html>
<html><head><title>BitLM CPU Inference</title>
<style>body{font-family:sans-serif;max-width:800px;margin:2em auto;}
textarea{width:100%;min-height:120px;}.gen{color:#0a0;}</style></head>
<body><h2>BitLM CPU Inference (synth_5m, no AVX)</h2>
<form id=f><textarea id=p>The quick brown fox</textarea>
<input type=number id=n value=64 min=1 max=512> tokens
<button>generate</button></form>
<pre id=o></pre>
<script>
document.getElementById('f').addEventListener('submit',async e=>{
e.preventDefault();let p=document.getElementById('p').value;
let n=document.getElementById('n').value;let o=document.getElementById('o');
o.innerText='';let r=await fetch('/gen',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({prompt:p,n_new:+n})});
let rd=r.body.getReader();let dec=new TextDecoder();while(1){let{done,value}=await rd.read();if(done)break;
o.innerText+=dec.decode(value,{stream:true});}});
</script></body></html>"""
|
|
|
|
def make_app(ckpt_path, tokenizer_path):
    """Build the Flask app serving the UI page and the ``/gen`` endpoint.

    Args:
        ckpt_path: Checkpoint path handed to ``load_model``.
        tokenizer_path: Tokenizer JSON file loaded with
            ``Tokenizer.from_file``.

    Returns:
        A Flask application with two routes:
        ``GET /`` renders the ``HTML`` page, and ``POST /gen`` takes a JSON
        object with optional ``prompt`` (string) and ``n_new`` (integer,
        default 64) and responds with a ``text/plain`` stream. Malformed
        requests get a 400 response.
    """
    m, _ = load_model(ckpt_path)
    tok = Tokenizer.from_file(tokenizer_path)
    app = Flask(__name__)
    bos = tok.token_to_id('[BOS]')

    def bad_request(msg):
        # Plain-text 400 so the browser client can display it like any output.
        return Response(msg + '\n', status=400, mimetype='text/plain')

    @app.route('/')
    def root():
        return render_template_string(HTML)

    @app.route('/gen', methods=['POST'])
    def gen():
        # silent=True yields None instead of raising on a missing/invalid body.
        d = request.get_json(silent=True)
        if not isinstance(d, dict):
            return bad_request('expected a JSON object body')
        prompt = d.get('prompt', '')
        if not isinstance(prompt, str):
            return bad_request('"prompt" must be a string')
        try:
            n_new = int(d.get('n_new', 64))
        except (TypeError, ValueError):
            return bad_request('"n_new" must be an integer')

        def stream():
            # The prompt is echoed immediately; the completion follows as a
            # single chunk once generate() has finished.
            yield prompt
            yield ' [GEN] '
            # Mirror generate(): [BOS] is prepended only when the tokenizer
            # defines it and the prompt does not already start with it.
            # Adding 1 unconditionally would drop the first generated token.
            ids = tok.encode(prompt).ids
            prompt_len = len(ids)
            if bos is not None and (not ids or ids[0] != bos):
                prompt_len += 1
            t0 = time.time()
            out = generate(m, tok, prompt, n_new=n_new)
            elapsed = time.time() - t0
            new_ids = out[prompt_len:]
            yield tok.decode(new_ids) if new_ids else ''
            yield f'\n\n[{len(new_ids)} tok in {elapsed:.2f}s = {len(new_ids)/max(elapsed,1e-3):.1f} tok/s]'

        return Response(stream(), mimetype='text/plain')

    return app
|
|
|
|
def main():
    """Command-line entry point.

    With ``--bench`` the model and tokenizer are loaded and ``benchmark`` is
    run once; otherwise the Flask app is built and served on all interfaces
    at ``--port`` using a single-threaded server.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', required=True)
    parser.add_argument('--tokenizer', default='rdv4_tokenizer.json')
    parser.add_argument('--port', type=int, default=5002)
    parser.add_argument('--bench', action='store_true')
    parser.add_argument('--n-new', type=int, default=64)
    opts = parser.parse_args()

    if not opts.bench:
        # Server mode.
        app = make_app(opts.ckpt, opts.tokenizer)
        print(f'serving on http://0.0.0.0:{opts.port}', flush=True)
        app.run(host='0.0.0.0', port=opts.port, debug=False, threaded=False)
        return

    # Benchmark-only mode: no HTTP server is started.
    model, _ = load_model(opts.ckpt)
    tokenizer = Tokenizer.from_file(opts.tokenizer)
    benchmark(model, tokenizer, n_new=opts.n_new)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|