| |
| """ |
| rollout_bench.py — the COMBINED-THESIS benchmark. It measures the same endpoint |
| that verifiers points its RL rollouts at (see configs/endpoints.toml), but frames |
| the numbers the way an RL post-training run cares about: |
| |
| rollout throughput (completions/sec, tokens/sec) <- THE WIN |
| TTFT (time to first token) <- ~unchanged with DFlash |
| acceptance length tau <- WHY it's faster |
| projected $/run saved <- WHY it's CHEAPER |
| |
| The thesis is "lossless DFlash speculative decoding makes RL post-training |
| cheaper." RL spends most of its wall-clock generating rollouts, so a faster |
| rollout endpoint — at IDENTICAL greedy output — buys the same reward curve for |
| fewer GPU-hours. This script measures that, live, against whatever is serving on |
| --base-url. It is a sibling of measure.py and reuses the same conventions: |
| stdlib urllib only, streaming /v1/completions, greedy decode, best-effort read of |
| vLLM /metrics. The ONE design rule: baseline vs DFlash is a one-flag swap on the |
| SERVER (serve_vllm.py --mode), never a change here — so the same command produces |
| both halves of the A/B. |
| |
| Workload: an RL "rollout batch" = a fixed prompt set, replayed identically, with |
| --rollouts-per-example completions per prompt. The workload is deterministic |
| (temperature 0 by default) so the BASELINE and DFLASH runs do identical work and |
| the only thing that moves is speed. |
| |
| acceptance length tau: |
| tau = mean tokens committed per target forward pass. With gamma=7 it ranges |
| from 1 (all drafts rejected, +1 bonus) to 8 (all accepted + bonus). tau is NOT |
| published in any Laguna/DFlash primary source — the model card gives per-position |
| acceptance rates only (position-1 ~70.7%, decaying to ~2% at position-7). So we |
| MEASURE it here from vLLM /metrics deltas. Expect roughly 2-3; never quote a |
| published figure. None is printed if /metrics is unavailable — read it off the |
| server's /metrics by hand then. VERIFY the exact metric names at onboarding. |
| |
| Losslessness: |
| --assert-parity runs the deterministic (greedy) workload TWICE against the same |
| endpoint and asserts byte-identical completions. On a correct speculative-decoding |
| implementation greedy output is invariant, so two runs must match. (The |
| baseline-vs-DFlash cross-server parity check lives in evals/humaneval_subset.py |
| --parity; this in-run check guards against nondeterminism in the served config.) |
| |
| This does NOT fabricate anything. Every number comes from live HTTP calls. If the |
| endpoint is down you get an error, not a made-up result. |
| |
| Usage: |
| # measure a DFlash run and project savings at $3.50/GPU-hour |
| python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\ |
| --label dflash --prompts 8 --rollouts-per-example 8 --max-tokens 512 \\ |
| --hourly-rate 3.50 --out results/rollout_dflash.json |
| |
| # measure the baseline (re-serve with serve_vllm.py --mode baseline first) |
| python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\ |
| --label baseline --hourly-rate 3.50 --out results/rollout_baseline.json |
| |
| # prove losslessness: two greedy runs against the same endpoint must be identical |
| python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\ |
| --label dflash --assert-parity |
| |
| Requires only the stdlib (urllib), so no extra venue deps. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| import urllib.request |
| from statistics import mean |
|
|
| |
| |
# gamma: draft tokens proposed per target forward pass. estimate_tau() divides
# the draft-token counter delta by this to approximate the number of passes.
# NOTE(review): presumably must match the server's speculative config — confirm.
NUM_SPECULATIVE_TOKENS: int = 7
|
|
| |
| |
# Fixed, built-in prompt set. main() takes a prefix of this list (--prompts) and
# replays it identically on every run, so two runs do the same work.
PROMPTS: list[str] = [
    # sequences / searching
    "Write a Python function that returns the nth Fibonacci number iteratively.",
    "Implement binary search over a sorted list in Python. Return the index or -1.",
    # strings
    "Write a function to check if a string is a palindrome, ignoring case and spaces.",
    # sorting / merging
    "Implement quicksort in Python.",
    "Write a function that merges two sorted lists into one sorted list.",
    # number theory
    "Write a Python function that returns the prime factors of an integer.",
    # strings / nested structures
    "Implement a function that reverses words in a sentence in place.",
    "Write a function that flattens an arbitrarily nested list of integers.",
]
|
|
|
|
def _parse_spec_metrics(text: str) -> dict:
    """Extract the spec-decode counters from Prometheus text-exposition output.

    Metric names are matched exactly (after dropping any ``namespace:`` prefix
    and a trailing ``_total``), so look-alike series such as
    ``..._accepted_tokens_per_pos_total`` or ``..._accepted_tokens_created``
    can no longer overwrite the real counter. Samples of the same counter that
    differ only in labels are summed.

    Args:
        text: Body of a ``/metrics`` response.

    Returns:
        Mapping of counter name (without prefix/suffix) to its summed value.
        Counters that do not appear are simply absent.
    """
    wanted = (
        "spec_decode_num_accepted_tokens",
        "spec_decode_num_draft_tokens",
        "spec_decode_num_emitted_tokens",
    )
    out: dict = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Sample syntax: name[{labels}] value [timestamp]
        if "{" in line:
            name, _, tail = line.partition("{")
            # Label values may contain spaces; everything after the LAST "}"
            # is "value [timestamp]".
            tail = tail.rpartition("}")[2]
        else:
            parts = line.split(None, 1)
            name = parts[0]
            tail = parts[1] if len(parts) > 1 else ""
        fields = tail.split()
        if not fields:
            continue
        metric = name.strip().rpartition(":")[2]  # drop e.g. a "vllm:" prefix
        if metric.endswith("_total"):
            metric = metric[: -len("_total")]
        if metric not in wanted:
            continue
        try:
            # First field is the value; a trailing field, if any, is a timestamp.
            value = float(fields[0])
        except ValueError:
            continue
        out[metric] = out.get(metric, 0.0) + value
    return out


def _try_metrics(base_url: str) -> dict:
    """Best-effort read of vLLM Prometheus spec-decode counters.

    Args:
        base_url: Endpoint root; ``/metrics`` is appended to it.

    Returns:
        Counter name -> value for whichever of the accepted/draft/emitted
        spec-decode counters are exposed. Empty dict if the endpoint cannot be
        read or exposes none of them.
    """
    try:
        with urllib.request.urlopen(base_url.rstrip("/") + "/metrics", timeout=10) as r:
            text = r.read().decode()
    except Exception:
        # Deliberately broad: metrics are optional, and a missing or broken
        # /metrics endpoint must never abort the benchmark itself.
        return {}
    return _parse_spec_metrics(text)
|
|
|
|
def generate_one(base_url: str, model: str, prompt: str, max_tokens: int,
                 temperature: float) -> dict:
    """Stream one completion from ``/v1/completions`` and time it.

    The token count is taken from the server's ``usage.completion_tokens``
    (requested via ``stream_options.include_usage``) when the stream reports
    it. A streamed chunk is not guaranteed to hold exactly one token (e.g. when
    speculative decoding commits several tokens per step), so counting chunks
    would under-report throughput. If no usage is reported, the number of
    non-empty text chunks is used as a fallback.

    Args:
        base_url: Endpoint root; ``/v1/completions`` is appended.
        model: Served model name sent in the request.
        prompt: Prompt text.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature (0.0 = greedy).

    Returns:
        Dict with ``ttft_s`` (seconds to the first non-empty text chunk, or
        None if no text arrived), ``total_s``, ``new_tokens``,
        ``tokens_per_s`` and the concatenated ``text``.

    Raises:
        RuntimeError: If the stream carries an ``error`` payload.
        Exception: Transport/JSON errors from urllib/json propagate unchanged.
    """
    url = base_url.rstrip("/") + "/v1/completions"
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
        # Ask for a final usage chunk so tokens are counted by the server.
        "stream_options": {"include_usage": True},
    }
    data = json.dumps(payload).encode()
    req = urllib.request.Request(url, data=data,
                                 headers={"Content-Type": "application/json"})
    t0 = time.perf_counter()
    ttft = None
    n_chunks = 0
    reported_tokens = None
    chunks = []
    with urllib.request.urlopen(req, timeout=600) as r:
        for raw in r:
            line = raw.decode().strip()
            if not line.startswith("data:"):
                continue
            body = line[len("data:"):].strip()
            if body == "[DONE]":
                break
            obj = json.loads(body)
            if obj.get("error"):
                # Fail loudly rather than report an empty "completion".
                raise RuntimeError(f"server returned an error mid-stream: {obj['error']}")
            usage = obj.get("usage")
            if isinstance(usage, dict) and usage.get("completion_tokens") is not None:
                reported_tokens = int(usage["completion_tokens"])
            # The usage-only chunk has an empty "choices" list.
            choices = obj.get("choices") or [{}]
            piece = choices[0].get("text") or ""
            if piece:
                if ttft is None:
                    ttft = time.perf_counter() - t0
                n_chunks += 1
                chunks.append(piece)
    total = time.perf_counter() - t0
    n_tokens = reported_tokens if reported_tokens is not None else n_chunks
    decode_time = max(total - (ttft or 0.0), 1e-9)
    # Decode rate excludes the first token, whose latency is the TTFT. (If the
    # first chunk carried several tokens this slightly over-counts — an
    # approximation.)
    tps = (n_tokens - 1) / decode_time if n_tokens > 1 else 0.0
    return {
        "ttft_s": ttft,
        "total_s": total,
        "new_tokens": n_tokens,
        "tokens_per_s": tps,
        "text": "".join(chunks),
    }
|
|
|
|
def run_rollout_batch(base_url: str, model: str, prompts: list[str],
                      rollouts_per_example: int, max_tokens: int,
                      temperature: float, label: str) -> list[dict]:
    """Replay the prompt set ``rollouts_per_example`` times — one RL rollout batch.

    Requests are issued sequentially: the whole prompt set once, then again,
    and so on. One progress line is printed per completion.

    Args:
        base_url: Endpoint root passed through to ``generate_one``.
        model: Served model name.
        prompts: Prompts to replay, in order.
        rollouts_per_example: Number of passes over ``prompts``.
        max_tokens: Per-completion cap on new tokens.
        temperature: Sampling temperature.
        label: Tag shown in the progress lines.

    Returns:
        One ``generate_one`` result dict per completion, in issue order.
    """
    runs = []
    total = len(prompts) * rollouts_per_example
    k = 0
    for _ in range(rollouts_per_example):
        for prompt in prompts:
            k += 1
            res = generate_one(base_url, model, prompt, max_tokens, temperature)
            runs.append(res)
            # ttft_s is None when the completion produced no text; formatting
            # None with ":.3f" would raise and abort the whole batch.
            ttft = res["ttft_s"]
            ttft_txt = f"{ttft:.3f}s" if ttft is not None else "n/a"
            print(f" [{label}] rollout {k}/{total} "
                  f"tps={res['tokens_per_s']:.1f} ttft={ttft_txt}")
    return runs
|
|
|
|
def estimate_tau(before: dict, after: dict,
                 num_speculative_tokens: int | None = None) -> float | None:
    """Estimate acceptance length tau from two /metrics snapshots.

    tau = committed tokens per target pass, where each pass commits its
    accepted draft tokens plus one bonus token, and the number of passes is
    approximated as ``draft_tokens / gamma``.

    Args:
        before: Counter snapshot taken before the workload.
        after: Counter snapshot taken after the workload.
        num_speculative_tokens: gamma, the draft tokens proposed per pass.
            Defaults to the module-level ``NUM_SPECULATIVE_TOKENS``.

    Returns:
        The estimated tau, or None when it cannot be computed: counters are
        missing, no draft tokens were produced in the window, or a counter
        went backwards (e.g. the server restarted between snapshots).

    Raises:
        ValueError: If ``num_speculative_tokens`` is not positive.
    """
    gamma = NUM_SPECULATIVE_TOKENS if num_speculative_tokens is None else num_speculative_tokens
    if gamma <= 0:
        raise ValueError(f"num_speculative_tokens must be positive, got {gamma}")
    acc = (after.get("spec_decode_num_accepted_tokens", 0)
           - before.get("spec_decode_num_accepted_tokens", 0))
    draft = (after.get("spec_decode_num_draft_tokens", 0)
             - before.get("spec_decode_num_draft_tokens", 0))
    # A negative delta means the counters were reset mid-window; any ratio
    # computed from it would be meaningless, so report "unavailable" instead.
    if draft <= 0 or acc < 0:
        return None
    passes = draft / gamma
    return (acc + passes) / passes
|
|
|
|
def assert_parity(base_url: str, model: str, prompts: list[str], max_tokens: int) -> dict:
    """Run the GREEDY workload twice and require byte-identical completions.

    Each prompt is completed once per run at temperature 0.0; the two runs are
    compared pairwise by generated text. A mismatch means the served config is
    nondeterministic (or broken), not lossless.

    Args:
        base_url: Endpoint root.
        model: Served model name.
        prompts: Prompts to replay in both runs.
        max_tokens: Per-completion cap on new tokens.

    Returns:
        Dict with ``parity_pairs``, ``identical``, ``mismatches`` and
        ``lossless`` (only returned when every pair matched).

    Raises:
        AssertionError: If any pair of completions differs.
    """
    print("[parity] greedy run A ...")
    a = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-A")
    print("[parity] greedy run B ...")
    b = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-B")
    mismatches = sum(1 for x, y in zip(a, b) if x["text"] != y["text"])
    identical = len(a) - mismatches
    result = {
        "parity_pairs": len(a),
        "identical": identical,
        "mismatches": mismatches,
        "lossless": mismatches == 0,
    }
    print(json.dumps(result, indent=2))
    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed under `python -O`, which would let a failed check print PASS.
    if mismatches != 0:
        raise AssertionError(
            f"PARITY FAILED: {mismatches}/{len(a)} greedy completions differed across "
            f"two runs of the same endpoint — output is NOT deterministic/lossless."
        )
    print("[parity] PASS — greedy output is byte-identical across runs (lossless).")
    return result
|
|
|
|
def _write_json(path: str, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``path``, creating parent dirs."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"[rollout_bench] wrote {path}")


def _mean_or_none(values: list) -> float | None:
    """Arithmetic mean of ``values``, or None when there is nothing to average."""
    return mean(values) if values else None


def main() -> None:
    """CLI entry point: run the parity check or the rollout-throughput batch.

    With ``--assert-parity`` only the two-run greedy comparison is performed
    (and optionally written to ``--out``). Otherwise the rollout batch is
    timed, spec-decode counters are sampled before and after it, a summary is
    printed as JSON and, if ``--out`` is given, written together with the
    per-rollout results. ``--hourly-rate`` adds a cost projection; together
    with ``--baseline-tps`` it also projects savings versus the baseline.
    """
    p = argparse.ArgumentParser(
        description="Rollout-throughput benchmark (completions/sec, tokens/sec, TTFT, "
                    "acceptance length tau, projected $/run) against an OpenAI-compatible "
                    "endpoint. Measures live; never fabricates.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--base-url", default="http://localhost:8000",
                   help="OpenAI-compatible endpoint root (vLLM serves /v1 and /metrics under it).")
    p.add_argument("--model", default="laguna",
                   help="Served model name/id (serve_vllm.py registers the alias 'laguna').")
    p.add_argument("--label", default="dflash",
                   help="Tag for the output: baseline | dflash. Just labels the JSON.")
    p.add_argument("--prompts", type=int, default=len(PROMPTS),
                   help="How many of the built-in prompts to use (1..%d)." % len(PROMPTS))
    p.add_argument("--rollouts-per-example", type=int, default=8,
                   help="Completions sampled per prompt — mirrors the RL config's group size.")
    p.add_argument("--max-tokens", type=int, default=512,
                   help="Max new tokens per completion. Match the RL sampling cap for honest $/run.")
    p.add_argument("--temperature", type=float, default=0.0,
                   help="0.0 = greedy/deterministic (the lossless-comparable workload). "
                        "Keep 0 for the A/B so baseline and DFlash do identical work.")
    p.add_argument("--hourly-rate", type=float, default=None,
                   help="GPU $/hour. If set, projects rollout-batch cost and (with --baseline-tps) savings.")
    p.add_argument("--baseline-tps", type=float, default=None,
                   help="Baseline tokens/sec from a prior --label baseline run. Lets this run project "
                        "the $ SAVED vs baseline for the same rollout workload.")
    p.add_argument("--assert-parity", action="store_true",
                   help="Run the greedy workload twice and assert byte-identical output (lossless check). "
                        "Exits nonzero on mismatch. Skips the throughput batch.")
    p.add_argument("--out", default=None, help="Write JSON summary here (e.g. results/rollout_dflash.json).")
    args = p.parse_args()

    # Clamp --prompts into 1..len(PROMPTS) so at least one prompt is always used.
    prompts = PROMPTS[:max(1, min(args.prompts, len(PROMPTS)))]

    if args.assert_parity:
        result = assert_parity(args.base_url, args.model, prompts, args.max_tokens)
        if args.out:
            _write_json(args.out, {"label": args.label, "parity": result})
        return

    # An empty batch has nothing to measure; reject it up front instead of
    # failing later while averaging zero results.
    if args.rollouts_per_example < 1:
        p.error("--rollouts-per-example must be >= 1")

    # Snapshot the counters around the batch so tau reflects this workload only.
    before = _try_metrics(args.base_url)
    t_start = time.perf_counter()
    runs = run_rollout_batch(args.base_url, args.model, prompts,
                             args.rollouts_per_example, args.max_tokens,
                             args.temperature, args.label)
    wall_s = time.perf_counter() - t_start
    after = _try_metrics(args.base_url)

    tau = estimate_tau(before, after)
    total_tokens = sum(r["new_tokens"] for r in runs)
    n_rollouts = len(runs)
    completions_per_s = n_rollouts / wall_s if wall_s > 0 else 0.0
    tokens_per_s_aggregate = total_tokens / wall_s if wall_s > 0 else 0.0
    # ttft_s is None for completions that produced no text; average the rest,
    # and report None (not a crash) if there is nothing to average.
    ttfts = [r["ttft_s"] for r in runs if r["ttft_s"] is not None]

    summary = {
        "label": args.label,
        "model": args.model,
        "base_url": args.base_url,
        "prompts": len(prompts),
        "rollouts_per_example": args.rollouts_per_example,
        "n_rollouts": n_rollouts,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "wall_s": wall_s,
        "completions_per_s": completions_per_s,
        "total_new_tokens": total_tokens,
        "tokens_per_s_aggregate": tokens_per_s_aggregate,
        "tokens_per_s_mean_per_rollout": _mean_or_none([r["tokens_per_s"] for r in runs]),
        "ttft_s_mean": _mean_or_none(ttfts),
        "acceptance_length_tau": tau,
        "spec_metrics_before": before,
        "spec_metrics_after": after,
    }

    if args.hourly_rate is not None:
        # Cost of this batch = measured wall-clock hours x GPU hourly rate.
        batch_cost = (wall_s / 3600.0) * args.hourly_rate
        cost = {"hourly_rate": args.hourly_rate, "batch_cost_usd": batch_cost}
        if args.baseline_tps and args.baseline_tps > 0 and total_tokens > 0:
            # Time the baseline would need to emit the same number of tokens.
            baseline_wall_s = total_tokens / args.baseline_tps
            baseline_cost = (baseline_wall_s / 3600.0) * args.hourly_rate
            cost.update({
                "baseline_tps_reference": args.baseline_tps,
                "projected_baseline_wall_s": baseline_wall_s,
                "projected_baseline_cost_usd": baseline_cost,
                "projected_savings_usd": baseline_cost - batch_cost,
                # baseline_tps > 0 is guaranteed by the enclosing condition.
                "speedup_x": tokens_per_s_aggregate / args.baseline_tps,
            })
        summary["cost_projection"] = cost

    print(json.dumps(summary, indent=2))
    if args.out:
        # The file additionally carries the per-rollout results.
        _write_json(args.out, {**summary, "runs": runs})
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|