| |
| """ |
| measure.py — the benchmark harness. Hits an OpenAI-compatible endpoint (the one |
| `vllm serve` exposes) and records the three demo numbers: |
| |
| tokens/sec (decode throughput) <- THE WIN |
| TTFT (time to first token) <- should be ~unchanged with DFlash |
| acceptance length tau <- WHY it's faster (read from vLLM metrics) |
| |
| Run it twice on the GPU host — once against the baseline server, once against the |
| DFlash server — and diff the JSON. That diff IS the before/after table. |
| |
| This file is endpoint-driven, so it runs anywhere (including the Mac) AS LONG AS |
| something is serving on --base-url. On the Mac you can point it at a local |
| tiny-model OpenAI server to shape-test; on the GPU host you point it at vLLM. |
| |
| acceptance length tau: |
| tau = mean(number of tokens committed per target forward pass). |
| With a draft of gamma=7, tau ranges from 1 (everything rejected, +1 bonus) |
| up to gamma+1=8 (all accepted + bonus). The DFlash card publishes per-position |
| acceptance only (~70.7% at position 1, decaying to ~2% by position 7), NOT a |
| tau figure -- measure tau on the GPU host (expect roughly 2-3). vLLM exposes |
| accepted/draft counts in its metrics; we |
| read them from /metrics (Prometheus) when present; otherwise tau is reported |
| as null and has to be estimated by hand from the speedup. VERIFY AT ONBOARDING |
| which metric names the vLLM build uses |
| (e.g. vllm:spec_decode_num_accepted_tokens / _num_draft_tokens). |
| |
| Usage: |
| python bench/measure.py --base-url http://localhost:8000 --model laguna \ |
| --label dflash --out results/dflash.json --n 20 |
| python bench/measure.py --base-url http://localhost:8000 --model laguna \ |
| --label baseline --out results/baseline.json --n 20 |
| |
| Requires only the standard library (urllib rather than requests), so there are no extra deps. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| import urllib.request |
| from statistics import mean |
|
|
# Fixed set of short code-generation prompts. main() walks through them in
# order, wrapping around when --n exceeds the number of prompts.
PROMPTS = [
    "Write a Python function that returns the nth Fibonacci number iteratively.",
    "Implement binary search over a sorted list in Python. Return the index or -1.",
    "Write a function to check if a string is a palindrome, ignoring case and spaces.",
    "Implement quicksort in Python.",
    "Write a function that merges two sorted lists into one sorted list.",
]
|
|
|
|
def _post(url: str, payload: dict) -> dict:
    """POST ``payload`` as a JSON body to ``url`` and return the parsed JSON reply.

    The request is sent with ``Content-Type: application/json`` and a 600 s
    timeout; the response body is decoded as UTF-8 before parsing.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=600) as response:
        raw_body = response.read()
    return json.loads(raw_body.decode())
|
|
|
|
def _parse_spec_metrics(text: str) -> dict:
    """Extract and sum the spec-decode counters from Prometheus text exposition.

    A sample line looks like ``name{labels} value [timestamp]`` (labels and
    timestamp optional). Matching is on the exact metric name after dropping
    any ``namespace:`` prefix (e.g. ``vllm:``) and a trailing ``_total``
    counter suffix. Related series such as ``..._created`` (creation
    timestamp) or ``..._per_pos_total`` (per-position breakdown) therefore
    cannot overwrite the totals. Samples of the same metric with different
    label sets (e.g. one per engine) are summed. Unparseable samples are
    skipped.
    """
    wanted = (
        "spec_decode_num_accepted_tokens",
        "spec_decode_num_draft_tokens",
        # Number of draft proposals (verification passes); exported by newer
        # vLLM builds -- VERIFY AT ONBOARDING. Absent keys are simply omitted.
        "spec_decode_num_drafts",
        "spec_decode_num_emitted_tokens",
    )
    out: dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "{" in line:
            name, _, rest = line.partition("{")
            # Labels end at the last '}' -- label values may contain spaces,
            # so don't split on whitespace before removing them.
            _, _, rest = rest.rpartition("}")
        else:
            name, _, rest = line.partition(" ")
        base = name.strip().rsplit(":", 1)[-1]
        if base.endswith("_total"):
            base = base[: -len("_total")]
        if base not in wanted:
            continue
        fields = rest.split()
        if not fields:
            continue
        try:
            # First token is the value; an optional second one is a timestamp.
            value = float(fields[0])
        except ValueError:
            continue
        out[base] = out.get(base, 0.0) + value
    return out


def _try_metrics(base_url: str) -> dict:
    """Best-effort read of vLLM Prometheus spec-decode counters.

    Fetches ``{base_url}/metrics`` (10 s timeout) and returns a mapping from
    bare counter name (e.g. ``"spec_decode_num_accepted_tokens"``) to its value
    summed over label sets. Returns ``{}`` when the endpoint cannot be read,
    and omits counters the server does not export (e.g. a baseline server
    without speculative decoding).
    """
    try:
        with urllib.request.urlopen(base_url.rstrip("/") + "/metrics", timeout=10) as r:
            text = r.read().decode()
    except Exception:
        # Deliberately best-effort: a missing or broken /metrics endpoint must
        # not abort the benchmark; tau is just reported as unknown.
        return {}
    return _parse_spec_metrics(text)
|
|
|
|
def measure_one(base_url: str, model: str, prompt: str, max_tokens: int) -> dict:
    """Stream one greedy completion from ``{base_url}/v1/completions`` and time it.

    The request uses ``temperature=0.0``, ``stream=True`` and
    ``stream_options.include_usage`` so the server reports its own
    completion-token count in a final usage chunk. If the server sends no
    usage, the number of non-empty text chunks is used instead; that can
    undercount when a chunk carries several tokens.

    Returns a dict with:
      ttft_s        seconds from sending the request to the first non-empty
                    text chunk, or None if no text arrived
      total_s       seconds until the stream ended
      new_tokens    completion tokens (server usage count, else chunk count)
      stream_chunks number of non-empty text chunks received
      tokens_per_s  (new_tokens - 1) / (total_s - ttft_s); 0.0 for <= 1 token
      text          the concatenated generated text

    Raises:
      RuntimeError: if the server sends an error payload inside the stream.
    """
    url = base_url.rstrip("/") + "/v1/completions"
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,
        # One streamed chunk can hold several tokens: speculative decoding
        # commits up to gamma+1 tokens per step. Counting chunks would then
        # hide exactly the speedup being measured, so ask for the real count.
        "stream_options": {"include_usage": True},
    }
    req = urllib.request.Request(url, data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json"})
    t0 = time.perf_counter()
    ttft = None
    usage_tokens = None
    chunks = []
    with urllib.request.urlopen(req, timeout=600) as r:
        for raw in r:
            line = raw.decode().strip()
            if not line.startswith("data:"):
                continue  # blank SSE separators and non-data fields
            body = line[len("data:"):].strip()
            if body == "[DONE]":
                break
            obj = json.loads(body)
            if "error" in obj or obj.get("object") == "error":
                # Don't silently record a truncated run as a fast one.
                raise RuntimeError(f"server reported an error mid-stream: {body}")
            usage = obj.get("usage")
            if usage and usage.get("completion_tokens") is not None:
                usage_tokens = int(usage["completion_tokens"])
            # The usage chunk arrives with an empty ``choices`` list.
            choices = obj.get("choices") or [{}]
            piece = choices[0].get("text") or ""
            if piece:
                if ttft is None:
                    ttft = time.perf_counter() - t0
                chunks.append(piece)
    total = time.perf_counter() - t0
    n_tokens = usage_tokens if usage_tokens is not None else len(chunks)
    # The decode rate excludes the first token, whose latency is TTFT.
    # Assumes the first chunk carries a single token -- TODO confirm per server.
    decode_time = max(total - (ttft or 0.0), 1e-9)
    tps = (n_tokens - 1) / decode_time if n_tokens > 1 else 0.0
    return {
        "ttft_s": ttft,
        "total_s": total,
        "new_tokens": n_tokens,
        "stream_chunks": len(chunks),
        "tokens_per_s": tps,
        "text": "".join(chunks),
    }
|
|
|
|
def _mean_or_none(values: list) -> float | None:
    """Return ``statistics.mean(values)``, or None when ``values`` is empty."""
    return mean(values) if values else None


def _acceptance_length(delta: dict, num_spec_tokens: int) -> float | None:
    """Derive tau (mean tokens committed per target forward pass) from counter deltas.

    Each verification pass commits its accepted draft tokens plus one token
    from the target (the bonus/correction token), so
    ``tau = 1 + accepted / passes``. ``passes`` comes from the
    ``spec_decode_num_drafts`` delta when the server exports it, otherwise it
    is ``spec_decode_num_draft_tokens / num_spec_tokens``. Returns None when no
    drafting was observed (e.g. a baseline server).
    """
    accepted = delta.get("spec_decode_num_accepted_tokens", 0.0)
    drafts = delta.get("spec_decode_num_drafts", 0.0)
    draft_tokens = delta.get("spec_decode_num_draft_tokens", 0.0)
    if drafts > 0:
        passes = drafts
    elif draft_tokens > 0:
        passes = draft_tokens / num_spec_tokens
    else:
        return None
    return 1.0 + accepted / passes


def main() -> None:
    """CLI entry point: benchmark one endpoint and print/write a JSON summary.

    Runs ``--n`` streamed generations, cycling through PROMPTS, and snapshots
    the spec-decode counters before and after. The summary (without per-run
    details) is printed; the full summary including ``runs`` is written to
    ``--out`` when given.
    """
    p = argparse.ArgumentParser(description="Benchmark tokens/sec, TTFT, acceptance length against a vLLM endpoint.")
    p.add_argument("--base-url", default="http://localhost:8000")
    p.add_argument("--model", default="laguna")
    p.add_argument("--label", required=True, help="baseline | dflash (used in the output).")
    p.add_argument("--n", type=int, default=20, help="Number of generations (cycles through the prompt set).")
    p.add_argument("--max-tokens", type=int, default=256)
    p.add_argument("--out", default=None, help="Write JSON here (e.g. results/dflash.json).")
    p.add_argument("--num-spec-tokens", type=int, default=NUM_SPECULATIVE_TOKENS,
                   help="Draft length gamma; used to derive tau when the server does not "
                        "export spec_decode_num_drafts (default: %(default)s).")
    args = p.parse_args()
    if args.num_spec_tokens < 1:
        p.error("--num-spec-tokens must be >= 1")

    before = _try_metrics(args.base_url)
    runs = []
    for i in range(args.n):
        prompt = PROMPTS[i % len(PROMPTS)]
        run = measure_one(args.base_url, args.model, prompt, args.max_tokens)
        runs.append(run)
        # ttft_s is None when a stream produced no text; keep the log line alive.
        ttft = f"{run['ttft_s']:.3f}s" if run["ttft_s"] is not None else "n/a"
        print(f" [{args.label}] run {i+1}/{args.n} "
              f"tps={run['tokens_per_s']:.1f} ttft={ttft}")
    after = _try_metrics(args.base_url)

    # The counters are cumulative, so only the change across this run is
    # meaningful. Assumes no other traffic hits the server meanwhile.
    delta = {k: after.get(k, 0.0) - before.get(k, 0.0)
             for k in sorted(before.keys() | after.keys())}
    tau = _acceptance_length(delta, args.num_spec_tokens)

    summary = {
        "label": args.label,
        "model": args.model,
        "base_url": args.base_url,
        "n": args.n,
        "tokens_per_s_mean": _mean_or_none([r["tokens_per_s"] for r in runs]),
        "ttft_s_mean": _mean_or_none([r["ttft_s"] for r in runs if r["ttft_s"] is not None]),
        "acceptance_length_tau": tau,
        "num_spec_tokens": args.num_spec_tokens,
        "spec_metrics_before": before,
        "spec_metrics_after": after,
        "spec_metrics_delta": delta,
        "runs": runs,
    }
    print(json.dumps({k: v for k, v in summary.items() if k != "runs"}, indent=2))
    if args.out:
        os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
        with open(args.out, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)
        print(f"[measure] wrote {args.out}")
|
|
|
|
# Default draft length (gamma) of the speculator. main() uses it to turn the
# draft-token counter into a count of verification passes when deriving tau.
# NOTE(review): should match the server's speculative config -- confirm at onboarding.
NUM_SPECULATIVE_TOKENS: int = 7
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    raise SystemExit(main())
|
|