# bitnet-1bitllm / _inf_server_simple.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Minimal CPU inference server for BitLMv47B (PyTorch path, no AVX-512).
For testing on hardware that can't run the optimized C binary
(e.g. 1-vCPU QEMU with no AVX). Loads a small .pt checkpoint, exposes a
streaming generation endpoint, and supports a built-in benchmark.
Run:
python _inf_server_simple.py --ckpt synth_5m.pt --port 5002
python _inf_server_simple.py --ckpt synth_5m.pt --bench # bench only
"""
import argparse, time, os, sys, math, json
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from flask import Flask, Response, request, render_template_string
from model_v47b import BitLMv47B
import model_v16 as v16
def load_model(ckpt_path, device='cpu'):
    """Build a BitLMv47B from a training checkpoint and load its weights.

    The checkpoint must contain ``'args'`` (training hyper-parameters, either a
    dict or an ``argparse.Namespace``) and ``'model'`` (a state dict).
    ``'val_bpc'`` is optional and only printed.

    Args:
        ckpt_path: Path to the ``.pt`` checkpoint.
        device: Device tensors are mapped to and the model is moved to.

    Returns:
        ``(model, kw)``: the model in eval mode on ``device``, and the
        constructor kwargs used to build it.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects (presumably needed because the checkpoint stores non-tensor
    # objects such as 'args'). Only load checkpoints from a trusted source.
    ck = torch.load(ckpt_path, map_location=device, weights_only=False)
    args = ck['args']
    # A saved argparse.Namespace supports `in` but not `args[key]`, so
    # normalise it to a plain dict before reading hyper-parameters.
    if isinstance(args, argparse.Namespace):
        args = vars(args)
    kw = {k: args[k]
          for k in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff',
                    'slope_groups')
          if k in args}
    # The context length may be saved as `seq_len` (preferred) or `max_seq_len`.
    if 'seq_len' in args:
        kw['max_seq_len'] = args['seq_len']
    elif 'max_seq_len' in args:
        kw['max_seq_len'] = args['max_seq_len']
    m = BitLMv47B(**kw)

    def _clean_key(key):
        # torch.compile wraps a module as `_orig_mod`; drop that level both for
        # compiled submodules ('a._orig_mod.b' -> 'a.b') and for a compiled
        # root model ('_orig_mod.a' -> 'a').
        key = key.replace('._orig_mod.', '.')
        if key.startswith('_orig_mod.'):
            key = key[len('_orig_mod.'):]
        return key

    sd = {_clean_key(k): v for k, v in ck['model'].items()}
    # strict=False: mismatches are only reported (as counts), not raised.
    miss, unexp = m.load_state_dict(sd, strict=False)
    print(f'load: missing={len(miss)} unexpected={len(unexp)} '
          f'val_bpc={ck.get("val_bpc")}', flush=True)
    m.eval().to(device)
    return m, kw
@torch.no_grad()
def generate(m, tok, prompt, n_new=64, temp=0.0, top_k=0, device='cpu'):
    """Autoregressively extend ``prompt`` by up to ``n_new`` tokens.

    Each step re-runs the model on the whole sequence so far (there is no KV
    cache) and reads the logits at the last position.

    Args:
        m: Model whose call ``m(x, None)`` returns ``(logits, _)``.
        tok: ``tokenizers.Tokenizer`` used to encode the prompt.
        prompt: Prompt text.
        n_new: Maximum number of tokens to generate.
        temp: ``<= 0`` selects greedy decoding. Otherwise tokens are sampled
            from ``softmax(logits / temp)``.
        top_k: When > 0 (and sampling), sample only among the ``top_k`` most
            likely tokens. It is clamped to the vocabulary size.
        device: Device for the input tensors.

    Returns:
        list[int]: the full id sequence, i.e. optional ``[BOS]``, then the
        prompt ids, then the generated ids. ``[BOS]`` is prepended when the
        tokenizer defines it and the encoded prompt does not already start
        with it. Generation stops early once the sequence reaches
        ``m.max_seq_len`` (2048 if the model has no such attribute). An empty
        prompt with no ``[BOS]`` yields ``[]``.
    """
    # NOTE(review): presumably fixes model_v16's Gumbel temperature for
    # inference; confirm against model_v16.
    v16.set_gumbel_tau(0.1)
    ids = list(tok.encode(prompt).ids)
    bos = tok.token_to_id('[BOS]')
    if bos is not None and (not ids or ids[0] != bos):
        ids = [bos] + ids
    if not ids:
        # Nothing to condition on; a length-0 input has no last position.
        return []
    x = torch.tensor([ids], dtype=torch.long, device=device)
    out_ids = list(ids)
    max_seq = getattr(m, 'max_seq_len', 2048)
    for _ in range(n_new):
        if x.size(1) >= max_seq:
            break
        logits, _ = m(x, None)
        last = logits[0, -1]
        if temp <= 0:
            nxt = int(last.argmax().item())
        else:
            probs = F.softmax(last / temp, dim=-1)
            if top_k > 0:
                # topk raises if k exceeds the number of classes, so clamp it.
                v, idx = probs.topk(min(top_k, probs.size(-1)))
                # multinomial accepts unnormalised weights.
                nxt = int(idx[torch.multinomial(v, 1).item()].item())
            else:
                nxt = int(torch.multinomial(probs, 1).item())
        out_ids.append(nxt)
        step = torch.tensor([[nxt]], dtype=torch.long, device=device)
        x = torch.cat([x, step], dim=1)
    return out_ids
def benchmark(m, tok, n_new=64, prompt='The quick brown fox', warmup=2, runs=3):
    """Measure greedy-generation throughput and print per-run and average tok/s.

    Args:
        m: Model passed to ``generate``. The summary line reads
            ``m.d_model`` and ``m.n_layers``.
        tok: Tokenizer passed to ``generate``.
        n_new: Tokens requested per timed run.
        prompt: Prompt text used for every run.
        warmup: Untimed 8-token generations run before measuring.
        runs: Number of timed runs.

    Returns:
        float: mean tokens/second over the timed runs.

    Raises:
        ValueError: if ``runs < 1``.
    """
    if runs < 1:
        raise ValueError('runs must be >= 1')
    print(f'Benchmark: prompt={prompt!r} n_new={n_new}', flush=True)
    ids = tok.encode(prompt).ids
    print(f'prompt tokens: {len(ids)}', flush=True)
    # generate() prepends [BOS] when the tokenizer defines one and the prompt
    # does not already start with it. Count it as prompt so it is not
    # reported as a generated token.
    bos = tok.token_to_id('[BOS]')
    n_prompt = len(ids) + int(bos is not None and (not ids or ids[0] != bos))
    # Warmup
    for _ in range(warmup):
        generate(m, tok, prompt, n_new=8)
    # Time it
    rates = []
    for r in range(runs):
        t0 = time.perf_counter()
        out = generate(m, tok, prompt, n_new=n_new)
        elapsed = time.perf_counter() - t0
        n_gen = len(out) - n_prompt
        rate = n_gen / max(elapsed, 1e-9)  # guard against a zero interval
        rates.append(rate)
        print(f' run {r+1}: {n_gen} tok in {elapsed:.2f}s = {rate:.2f} tok/s',
              flush=True)
    avg = sum(rates) / len(rates)
    print(f'\nAVG: {avg:.2f} tok/s ({n_new}-token generation, '
          f'{m.d_model}-d {m.n_layers}-layer model)', flush=True)
    return avg
# Single-page UI served at "/". It POSTs {prompt, n_new} as JSON to /gen and
# appends the plain-text response body to <pre id=o> as chunks arrive. One
# TextDecoder in streaming mode is used so multi-byte UTF-8 characters split
# across chunks decode correctly; the final decode() flushes any tail bytes.
# The string is rendered with render_template_string, so it must not contain
# Jinja delimiters ("{{", "{%", "{#").
HTML = """<!doctype html>
<html><head><title>BitLM CPU Inference</title>
<style>body{font-family:sans-serif;max-width:800px;margin:2em auto;}
textarea{width:100%;min-height:120px;}.gen{color:#0a0;}</style></head>
<body><h2>BitLM CPU Inference (synth_5m, no AVX)</h2>
<form id=f><textarea id=p>The quick brown fox</textarea>
<input type=number id=n value=64 min=1 max=512> tokens
<button>generate</button></form>
<pre id=o></pre>
<script>
document.getElementById('f').addEventListener('submit',async e=>{
e.preventDefault();let p=document.getElementById('p').value;
let n=document.getElementById('n').value;let o=document.getElementById('o');
o.innerText='';let r=await fetch('/gen',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({prompt:p,n_new:+n})});
let rd=r.body.getReader();let dec=new TextDecoder();while(1){let{done,value}=await rd.read();if(done)break;
o.innerText+=dec.decode(value,{stream:true});}
o.innerText+=dec.decode();});
</script></body></html>"""
def make_app(ckpt_path, tokenizer_path, max_new=512):
    """Load the model and tokenizer once and build the Flask app.

    Routes:
        ``GET /``: the HTML UI.
        ``POST /gen``: JSON object body ``{"prompt": str, "n_new": int}``
        (defaults ``""`` / 64). Responds with ``text/plain``. The prompt
        echo and the `` [GEN] `` marker are yielded before generation
        starts. The decoded continuation and a timing line are yielded once
        generation finishes. A malformed body returns HTTP 400.

    Args:
        ckpt_path: Checkpoint passed to ``load_model``.
        tokenizer_path: Path to a ``tokenizers`` JSON file.
        max_new: Upper bound applied to the client-supplied ``n_new``. It
            matches the UI's ``max=512``, so a request cannot tie up the
            server indefinitely. ``None`` disables the cap.

    Returns:
        flask.Flask: the configured application.
    """
    m, _ = load_model(ckpt_path)
    tok = Tokenizer.from_file(tokenizer_path)
    bos = tok.token_to_id('[BOS]')
    app = Flask(__name__)

    @app.route('/')
    def root():
        return render_template_string(HTML)

    @app.route('/gen', methods=['POST'])
    def gen():
        # silent=True: a missing or non-JSON body gives None instead of raising.
        d = request.get_json(silent=True)
        if not isinstance(d, dict):
            return Response('expected a JSON object body\n', status=400,
                            mimetype='text/plain')
        prompt = d.get('prompt', '')
        if not isinstance(prompt, str):
            return Response('prompt must be a string\n', status=400,
                            mimetype='text/plain')
        try:
            n_new = int(d.get('n_new', 64))
        except (TypeError, ValueError, OverflowError):
            return Response('n_new must be an integer\n', status=400,
                            mimetype='text/plain')
        if max_new is not None:
            n_new = min(n_new, max_new)
        n_new = max(0, n_new)
        # Mirror generate(): [BOS] is prepended only when the tokenizer has
        # one and the encoded prompt does not already start with it. Slicing
        # off exactly that many ids leaves only the generated tokens.
        ids = tok.encode(prompt).ids
        n_prompt = len(ids) + int(bos is not None and (not ids or ids[0] != bos))

        def stream():
            yield prompt
            yield ' [GEN] '
            t0 = time.perf_counter()
            out = generate(m, tok, prompt, n_new=n_new)
            elapsed = time.perf_counter() - t0
            new_ids = out[n_prompt:]
            yield tok.decode(new_ids) if new_ids else ''
            yield (f'\n\n[{len(new_ids)} tok in {elapsed:.2f}s = '
                   f'{len(new_ids)/max(elapsed,1e-3):.1f} tok/s]')

        return Response(stream(), mimetype='text/plain')

    return app
def main():
    """CLI entry point: benchmark with ``--bench``, otherwise serve the web UI.

    ``--n-new`` only affects the benchmark. The server always listens on
    0.0.0.0 at ``--port`` using Flask's single-threaded development server
    (``threaded=False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', required=True)
    parser.add_argument('--tokenizer', default='rdv4_tokenizer.json')
    parser.add_argument('--port', type=int, default=5002)
    parser.add_argument('--bench', action='store_true')
    parser.add_argument('--n-new', type=int, default=64)
    opts = parser.parse_args()

    if not opts.bench:
        server = make_app(opts.ckpt, opts.tokenizer)
        print(f'serving on http://0.0.0.0:{opts.port}', flush=True)
        server.run(host='0.0.0.0', port=opts.port, debug=False, threaded=False)
        return

    model, _ = load_model(opts.ckpt)
    tokenizer = Tokenizer.from_file(opts.tokenizer)
    benchmark(model, tokenizer, n_new=opts.n_new)


if __name__ == '__main__':
    main()