#!/usr/bin/env python3
"""
rollout_bench.py — the COMBINED-THESIS benchmark.

It measures the same endpoint that verifiers points its RL rollouts at (see
configs/endpoints.toml), but frames the numbers the way an RL post-training run
cares about:

    rollout throughput (completions/sec, tokens/sec)  <- THE WIN
    TTFT (time to first token)                        <- ~unchanged with DFlash
    acceptance length tau                             <- WHY it's faster
    projected $/run saved                             <- WHY it's CHEAPER

The thesis is "lossless DFlash speculative decoding makes RL post-training
cheaper." RL spends most of its wall-clock generating rollouts, so a faster
rollout endpoint — at IDENTICAL greedy output — buys the same reward curve for
fewer GPU-hours. This script measures that, live, against whatever is serving
on --base-url.

It is a sibling of measure.py and reuses the same conventions: stdlib urllib
only, streaming /v1/completions, greedy decode, best-effort read of vLLM
/metrics. The ONE design rule: baseline vs DFlash is a one-flag swap on the
SERVER (serve_vllm.py --mode), never a change here — so the same command
produces both halves of the A/B.

Workload:
    An RL "rollout batch" = a fixed prompt set, replayed identically, with
    --rollouts-per-example completions per prompt. The workload is
    deterministic (temperature 0 by default) so the BASELINE and DFLASH runs do
    identical work and the only thing that moves is speed.

Acceptance length tau:
    tau = mean tokens committed per target forward pass. With gamma=7 it ranges
    from 1 (all drafts rejected, +1 bonus) to 8 (all accepted + bonus). tau is
    NOT published in any Laguna/DFlash primary source — the model card gives
    per-position acceptance rates only (position-1 ~70.7%, decaying to ~2% at
    position-7). So we MEASURE it here from vLLM /metrics deltas. Expect
    roughly 2-3; never quote a published figure. tau is reported as None if
    /metrics is unavailable — read it off the server's /metrics by hand then.
    VERIFY the exact metric names at onboarding.

Token counting:
    With speculative decoding a single streamed chunk may carry several
    committed tokens, so counting chunks would undercount DFlash output. Each
    request therefore asks for a final usage chunk (stream_options.include_usage)
    and uses the server-reported completion_tokens when present, falling back to
    the chunk count otherwise.

Losslessness:
    --assert-parity runs the deterministic (greedy) workload TWICE against the
    same endpoint and asserts byte-identical completions. On a correct
    speculative-decoding implementation greedy output is invariant, so two runs
    must match. (The baseline-vs-DFlash cross-server parity check lives in
    evals/humaneval_subset.py --parity; this in-run check guards against
    nondeterminism in the served config.)

This does NOT fabricate anything. Every number comes from live HTTP calls. If
the endpoint is down you get an error, not a made-up result.

Usage:
    # measure a DFlash run and price it at $3.50/GPU-hour
    # (add --baseline-tps <tokens_per_s_aggregate of a baseline run> to project $ saved)
    python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
        --label dflash --prompts 8 --rollouts-per-example 8 --max-tokens 512 \\
        --hourly-rate 3.50 --out results/rollout_dflash.json

    # measure the baseline (re-serve with serve_vllm.py --mode baseline first)
    python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
        --label baseline --hourly-rate 3.50 --out results/rollout_baseline.json

    # prove losslessness: two greedy runs against the same endpoint must be identical
    python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
        --label dflash --assert-parity

Requires only the stdlib (urllib), so there is nothing extra to install.
"""
from __future__ import annotations

import argparse
import json
import os
import time
import urllib.request
from collections.abc import Iterable, Iterator
from statistics import mean

# Draft length gamma, per the DFlash model card. Used only to estimate the number
# of target verification passes when turning /metrics counters into tau, and only
# when the server does not export a draft-pass counter.
NUM_SPECULATIVE_TOKENS = 7

# Spec-decode counters read from /metrics. Names are matched EXACTLY after
# stripping an optional "vllm:" namespace prefix and the Prometheus "_total"
# counter suffix, so look-alikes such as "..._created" or "..._per_pos" are not
# mistaken for the aggregate counters. VERIFY these names at onboarding.
_SPEC_METRIC_KEYS = (
    "spec_decode_num_accepted_tokens",
    "spec_decode_num_draft_tokens",
    "spec_decode_num_emitted_tokens",
    "spec_decode_num_drafts",
)

# The fixed rollout prompt set. Coding-style, matching the DFlash card's domain
# and measure.py's set, so tau and tokens/sec are comparable across the harness.
PROMPTS = [
    "Write a Python function that returns the nth Fibonacci number iteratively.",
    "Implement binary search over a sorted list in Python. Return the index or -1.",
    "Write a function to check if a string is a palindrome, ignoring case and spaces.",
    "Implement quicksort in Python.",
    "Write a function that merges two sorted lists into one sorted list.",
    "Write a Python function that returns the prime factors of an integer.",
    "Implement a function that reverses words in a sentence in place.",
    "Write a function that flattens an arbitrarily nested list of integers.",
]


def _metric_sample(line: str) -> tuple[str, float] | None:
    """Split one Prometheus text-exposition line into (bare metric name, value).

    Returns None for comments, blank lines and unparseable samples. The bare name
    has any "namespace:" prefix and a trailing "_total" removed, so servers that
    do or don't append the counter suffix map to the same key.
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    brace = line.find("{")
    if brace == -1:
        parts = line.split(None, 1)
        name, rest = parts[0], (parts[1] if len(parts) > 1 else "")
    else:
        # Label values may contain spaces, so only split after the closing brace.
        close = line.rfind("}")
        if close < brace:
            return None
        name, rest = line[:brace], line[close + 1:]
    fields = rest.split()
    if not fields:
        return None
    try:
        # The value is the first field; an optional timestamp may follow it.
        value = float(fields[0])
    except ValueError:
        return None
    name = name.strip().rpartition(":")[2]
    if name.endswith("_total"):
        name = name[: -len("_total")]
    return name, value


def _parse_spec_metrics(text: str) -> dict:
    """Extract the spec-decode counters from a /metrics body.

    Values are summed across label sets (e.g. several engines/models), so the
    result is the server-wide total for each key present. Missing keys are
    simply absent from the returned dict.
    """
    out: dict = {}
    for line in text.splitlines():
        sample = _metric_sample(line)
        if sample is None:
            continue
        name, value = sample
        if name in _SPEC_METRIC_KEYS:
            out[name] = out.get(name, 0.0) + value
    return out


def _try_metrics(base_url: str) -> dict:
    """Best-effort read of vLLM Prometheus spec-decode counters. Empty if absent."""
    try:
        with urllib.request.urlopen(base_url.rstrip("/") + "/metrics", timeout=10) as r:
            text = r.read().decode("utf-8", errors="replace")
    except Exception:
        # Deliberately best-effort: a missing /metrics endpoint must never abort
        # the benchmark; tau is then reported as None.
        return {}
    return _parse_spec_metrics(text)


def _safe_mean(values: Iterable[float]) -> float | None:
    """Arithmetic mean of *values*, or None when there are none (never raises)."""
    data = list(values)
    return mean(data) if data else None


def _fmt_seconds(value: float | None) -> str:
    """Format a duration for the progress line; "n/a" when it was never measured."""
    return "n/a" if value is None else f"{value:.3f}s"


def _iter_sse_events(lines: Iterable[bytes]) -> Iterator[dict]:
    """Yield the decoded JSON payload of each SSE ``data:`` line until ``[DONE]``.

    Blank lines and non-data lines (e.g. ``: keep-alive`` comments) are skipped.
    """
    for raw in lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            return
        yield json.loads(body)


def generate_one(base_url: str, model: str, prompt: str, max_tokens: int,
                 temperature: float) -> dict:
    """One streamed completion. Returns timing + the generated text.

    Returns a dict with ``ttft_s`` (None if no text was streamed), ``total_s``,
    ``new_tokens`` (server-reported completion_tokens when the usage chunk is
    returned, else the number of non-empty text chunks), ``stream_chunks``,
    ``tokens_per_s`` (decode-phase rate after the first chunk) and ``text``.
    """
    url = base_url.rstrip("/") + "/v1/completions"
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,  # 0.0 => greedy => deterministic => lossless-comparable
        "stream": True,
        # Ask for a final usage chunk: with speculative decoding one streamed
        # chunk may hold several tokens, so chunk counts alone undercount.
        "stream_options": {"include_usage": True},
    }
    data = json.dumps(payload).encode()
    req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})

    t0 = time.perf_counter()
    ttft = None
    n_chunks = 0
    usage_tokens = None
    chunks: list[str] = []
    with urllib.request.urlopen(req, timeout=600) as r:
        for obj in _iter_sse_events(r):
            usage = obj.get("usage")
            if usage and usage.get("completion_tokens") is not None:
                usage_tokens = int(usage["completion_tokens"])
            # The usage chunk carries "choices": [] — must not be indexed blindly.
            choices = obj.get("choices") or []
            piece = choices[0].get("text", "") if choices else ""
            if piece:
                if ttft is None:
                    ttft = time.perf_counter() - t0
                n_chunks += 1
                chunks.append(piece)
    total = time.perf_counter() - t0

    n_tokens = usage_tokens if usage_tokens is not None else n_chunks
    decode_time = max(total - (ttft or 0.0), 1e-9)
    tps = (n_tokens - 1) / decode_time if n_tokens > 1 else 0.0
    return {
        "ttft_s": ttft,
        "total_s": total,
        "new_tokens": n_tokens,
        "stream_chunks": n_chunks,
        "tokens_per_s": tps,
        "text": "".join(chunks),
    }


def run_rollout_batch(base_url: str, model: str, prompts: list[str],
                      rollouts_per_example: int, max_tokens: int,
                      temperature: float, label: str) -> list[dict]:
    """Replay the prompt set rollouts_per_example times — one RL rollout batch.

    Requests are issued sequentially; returns one generate_one() result per
    completion, in issue order.
    """
    runs: list[dict] = []
    total = len(prompts) * rollouts_per_example
    for _ in range(rollouts_per_example):
        for prompt in prompts:
            res = generate_one(base_url, model, prompt, max_tokens, temperature)
            runs.append(res)
            print(f"  [{label}] rollout {len(runs)}/{total} "
                  f"tps={res['tokens_per_s']:.1f} ttft={_fmt_seconds(res['ttft_s'])}")
    return runs


def estimate_tau(before: dict, after: dict) -> float | None:
    """tau from vLLM /metrics deltas. None if counters are unavailable.

    Committed tokens per target pass = accepted + 1 bonus per pass. The number of
    passes is taken from the ``spec_decode_num_drafts`` counter when the server
    exports it, otherwise estimated as draft_tokens / gamma (as measure.py does).
    """
    def delta(key: str) -> float:
        return after.get(key, 0) - before.get(key, 0)

    accepted = delta("spec_decode_num_accepted_tokens")
    passes = delta("spec_decode_num_drafts")
    if passes <= 0:
        draft = delta("spec_decode_num_draft_tokens")
        passes = draft / NUM_SPECULATIVE_TOKENS if draft > 0 else 0
    if passes <= 0:
        return None
    return (accepted + passes) / passes  # +1 bonus token per verification pass


def assert_parity(base_url: str, model: str, prompts: list[str], max_tokens: int) -> dict:
    """Run the GREEDY workload twice and assert byte-identical completions.

    On correct speculative decoding, greedy output is invariant — two runs MUST
    match. A mismatch means the served config is nondeterministic (or broken),
    not lossless. Raises AssertionError on any mismatch so a CI/demo run fails
    loudly — explicitly, so the check survives ``python -O``.
    """
    print("[parity] greedy run A ...")
    a = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-A")
    print("[parity] greedy run B ...")
    b = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-B")
    mismatches = sum(1 for x, y in zip(a, b) if x["text"] != y["text"])
    identical = len(a) - mismatches
    result = {
        "parity_pairs": len(a),
        "identical": identical,
        "mismatches": mismatches,
        "lossless": mismatches == 0,
    }
    print(json.dumps(result, indent=2))
    if mismatches:
        raise AssertionError(
            f"PARITY FAILED: {mismatches}/{len(a)} greedy completions differed across "
            f"two runs of the same endpoint — output is NOT deterministic/lossless."
        )
    print("[parity] PASS — greedy output is byte-identical across runs (lossless).")
    return result


def _project_cost(wall_s: float, total_tokens: int, tokens_per_s_aggregate: float,
                  hourly_rate: float, baseline_tps: float | None) -> dict:
    """Price this rollout batch and, given a baseline rate, the projected savings.

    The baseline projection assumes the SAME token count generated at
    *baseline_tps* (aggregate tokens/sec of a prior baseline run).
    """
    batch_cost = (wall_s / 3600.0) * hourly_rate
    cost = {"hourly_rate": hourly_rate, "batch_cost_usd": batch_cost}
    if baseline_tps and baseline_tps > 0 and total_tokens > 0:
        baseline_wall_s = total_tokens / baseline_tps
        baseline_cost = (baseline_wall_s / 3600.0) * hourly_rate
        cost.update({
            "baseline_tps_reference": baseline_tps,
            "projected_baseline_wall_s": baseline_wall_s,
            "projected_baseline_cost_usd": baseline_cost,
            "projected_savings_usd": baseline_cost - batch_cost,
            "speedup_x": tokens_per_s_aggregate / baseline_tps,
        })
    return cost


def _write_json(path: str, obj: dict) -> None:
    """Write *obj* as indented JSON to *path*, creating parent directories."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    print(f"[rollout_bench] wrote {path}")


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark."""
    p = argparse.ArgumentParser(
        description="Rollout-throughput benchmark (completions/sec, tokens/sec, TTFT, "
                    "acceptance length tau, projected $/run) against an OpenAI-compatible "
                    "endpoint. Measures live; never fabricates.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--base-url", default="http://localhost:8000",
                   help="OpenAI-compatible endpoint root (vLLM serves /v1 and /metrics under it).")
    p.add_argument("--model", default="laguna",
                   help="Served model name/id (serve_vllm.py registers the alias 'laguna').")
    p.add_argument("--label", default="dflash",
                   help="Tag for the output: baseline | dflash. Just labels the JSON.")
    p.add_argument("--prompts", type=int, default=len(PROMPTS),
                   help="How many of the built-in prompts to use (1..%d)." % len(PROMPTS))
    p.add_argument("--rollouts-per-example", type=int, default=8,
                   help="Completions sampled per prompt — mirrors the RL config's group size.")
    p.add_argument("--max-tokens", type=int, default=512,
                   help="Max new tokens per completion. Match the RL sampling cap for honest $/run.")
    p.add_argument("--temperature", type=float, default=0.0,
                   help="0.0 = greedy/deterministic (the lossless-comparable workload). "
                        "Keep 0 for the A/B so baseline and DFlash do identical work.")
    p.add_argument("--hourly-rate", type=float, default=None,
                   help="GPU $/hour. If set, projects rollout-batch cost and (with --baseline-tps) savings.")
    p.add_argument("--baseline-tps", type=float, default=None,
                   help="Baseline tokens_per_s_aggregate from a prior --label baseline run. Lets this "
                        "run project the $ SAVED vs baseline for the same rollout workload.")
    p.add_argument("--assert-parity", action="store_true",
                   help="Run the greedy workload twice and assert byte-identical output (lossless check). "
                        "Exits nonzero on mismatch. Skips the throughput batch.")
    p.add_argument("--out", default=None,
                   help="Write JSON summary here (e.g. results/rollout_dflash.json).")
    return p


def main() -> None:
    """CLI entry point: parity check, or throughput batch + summary (+ $ projection)."""
    p = _build_parser()
    args = p.parse_args()
    prompts = PROMPTS[:max(1, min(args.prompts, len(PROMPTS)))]

    if args.assert_parity:
        result = assert_parity(args.base_url, args.model, prompts, args.max_tokens)
        if args.out:
            _write_json(args.out, {"label": args.label, "parity": result})
        return

    if args.rollouts_per_example < 1:
        # An empty batch would make every mean/throughput figure meaningless.
        p.error("--rollouts-per-example must be >= 1")

    before = _try_metrics(args.base_url)
    t_start = time.perf_counter()
    runs = run_rollout_batch(args.base_url, args.model, prompts, args.rollouts_per_example,
                             args.max_tokens, args.temperature, args.label)
    wall_s = time.perf_counter() - t_start
    after = _try_metrics(args.base_url)
    tau = estimate_tau(before, after)

    total_tokens = sum(r["new_tokens"] for r in runs)
    n_rollouts = len(runs)
    completions_per_s = n_rollouts / wall_s if wall_s > 0 else 0.0
    tokens_per_s_aggregate = total_tokens / wall_s if wall_s > 0 else 0.0

    summary = {
        "label": args.label,
        "model": args.model,
        "base_url": args.base_url,
        "prompts": len(prompts),
        "rollouts_per_example": args.rollouts_per_example,
        "n_rollouts": n_rollouts,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "wall_s": wall_s,
        "completions_per_s": completions_per_s,  # rollout throughput — the headline
        "total_new_tokens": total_tokens,
        "tokens_per_s_aggregate": tokens_per_s_aggregate,
        "tokens_per_s_mean_per_rollout": _safe_mean(r["tokens_per_s"] for r in runs),
        # None (not a crash) when no rollout streamed any text.
        "ttft_s_mean": _safe_mean(r["ttft_s"] for r in runs if r["ttft_s"] is not None),
        "acceptance_length_tau": tau,  # None if /metrics absent — read it off /metrics by hand then
        "spec_metrics_before": before,
        "spec_metrics_after": after,
    }

    # ---- projected $/run -------------------------------------------------
    # Cost of THIS rollout batch at the given GPU price. If a baseline tokens/sec
    # is supplied, also project what the SAME workload would have cost at baseline
    # speed, and the savings — the dollars-and-cents form of the thesis.
    if args.hourly_rate is not None:
        summary["cost_projection"] = _project_cost(
            wall_s, total_tokens, tokens_per_s_aggregate, args.hourly_rate, args.baseline_tps)

    print(json.dumps(summary, indent=2))
    if args.out:
        # Persist per-rollout detail alongside the summary for later inspection.
        _write_json(args.out, {**summary, "runs": runs})


if __name__ == "__main__":
    main()