lean-laguna / bench /rollout_bench.py
art87able's picture
Lean Laguna: lossless DFlash speculative decoding on Laguna XS.2 (harness, environment, results)
0a55ff6
#!/usr/bin/env python3
"""
rollout_bench.py — the COMBINED-THESIS benchmark. It measures the same endpoint
that verifiers points its RL rollouts at (see configs/endpoints.toml), but frames
the numbers the way an RL post-training run cares about:
rollout throughput (completions/sec, tokens/sec) <- THE WIN
TTFT (time to first token) <- ~unchanged with DFlash
acceptance length tau <- WHY it's faster
projected $/run saved <- WHY it's CHEAPER
The thesis is "lossless DFlash speculative decoding makes RL post-training
cheaper." RL spends most of its wall-clock generating rollouts, so a faster
rollout endpoint — at IDENTICAL greedy output — buys the same reward curve for
fewer GPU-hours. This script measures that, live, against whatever is serving on
--base-url. It is a sibling of measure.py and reuses the same conventions:
stdlib urllib only, streaming /v1/completions, greedy decode, best-effort read of
vLLM /metrics. The ONE design rule: baseline vs DFlash is a one-flag swap on the
SERVER (serve_vllm.py --mode), never a change here — so the same command produces
both halves of the A/B.
Workload: an RL "rollout batch" = a fixed prompt set, replayed identically, with
--rollouts-per-example completions per prompt. The workload is deterministic
(temperature 0 by default) so the BASELINE and DFLASH runs do identical work and
the only thing that moves is speed.
acceptance length tau:
tau = mean tokens committed per target forward pass. With gamma=7 it ranges
from 1 (all drafts rejected, +1 bonus) to 8 (all accepted + bonus). tau is NOT
published in any Laguna/DFlash primary source — the model card gives per-position
acceptance rates only (position-1 ~70.7%, decaying to ~2% at position-7). So we
MEASURE it here from vLLM /metrics deltas. Expect roughly 2-3; never quote a
published figure. If /metrics is unavailable, tau is reported as None (null in the
JSON output); read the counters off the server's /metrics by hand in that case.
VERIFY the exact metric names at onboarding.
Losslessness:
--assert-parity runs the deterministic (greedy) workload TWICE against the same
endpoint and asserts byte-identical completions. On a correct speculative-decoding
implementation greedy output is invariant, so two runs must match. (The
baseline-vs-DFlash cross-server parity check lives in evals/humaneval_subset.py
--parity; this in-run check guards against nondeterminism in the served config.)
This does NOT fabricate anything. Every number comes from live HTTP calls. If the
endpoint is down you get an error, not a made-up result.
Usage:
# measure a DFlash run and project savings at $3.50/GPU-hour
python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
--label dflash --prompts 8 --rollouts-per-example 8 --max-tokens 512 \\
--hourly-rate 3.50 --out results/rollout_dflash.json
# measure the baseline (re-serve with serve_vllm.py --mode baseline first)
python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
--label baseline --hourly-rate 3.50 --out results/rollout_baseline.json
# prove losslessness: two greedy runs against the same endpoint must be identical
python bench/rollout_bench.py --base-url http://localhost:8000 --model laguna \\
--label dflash --assert-parity
Requires only the standard library (urllib), so no extra venv dependencies.
"""
from __future__ import annotations
import argparse
import json
import os
import time
import urllib.request
from statistics import mean
# Draft length gamma, per the DFlash model card. Used only by estimate_tau, to
# estimate the number of target verification passes (draft_tokens / gamma) when
# turning /metrics counters into tau. It must match the server's configured
# number of speculative tokens for that estimate to be right.
NUM_SPECULATIVE_TOKENS = 7
# The fixed rollout prompt set. Coding-style, matching the DFlash card's domain
# and measure.py's set, so tau and tokens/sec are comparable across the harness.
# main() uses the first --prompts entries of this list, in order; the order also
# determines how assert_parity pairs run A with run B.
PROMPTS = [
    "Write a Python function that returns the nth Fibonacci number iteratively.",
    "Implement binary search over a sorted list in Python. Return the index or -1.",
    "Write a function to check if a string is a palindrome, ignoring case and spaces.",
    "Implement quicksort in Python.",
    "Write a function that merges two sorted lists into one sorted list.",
    "Write a Python function that returns the prime factors of an integer.",
    "Implement a function that reverses words in a sentence in place.",
    "Write a function that flattens an arbitrarily nested list of integers.",
]
# Spec-decode counters we snapshot, by bare name (no "vllm:" prefix, no
# "_total" suffix). VERIFY metric names at onboarding; these are the common
# vLLM ones. spec_decode_num_drafts (vLLM V1) counts verification passes
# exactly, which lets estimate_tau avoid guessing passes from the draft length.
_SPEC_METRIC_KEYS = (
    "spec_decode_num_accepted_tokens",
    "spec_decode_num_draft_tokens",
    "spec_decode_num_drafts",
    "spec_decode_num_emitted_tokens",
)


def _parse_spec_metrics(text: str) -> dict:
    """Sum the counters in _SPEC_METRIC_KEYS out of Prometheus text exposition.

    Each sample line is ``name[{labels}] value [timestamp]``. Matching is on the
    exact metric name, so companion series such as ``..._created`` (a creation
    timestamp) or ``..._per_pos_total`` (per-draft-position counts) are NOT
    mistaken for the totals. Values of the same counter across several label
    sets (e.g. multiple engines) are summed.
    """
    out: dict = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if "{" in line:
            # Metric names cannot contain '{'; the value follows the last '}'.
            name, _, rest = line.partition("{")
            _, _, rest = rest.rpartition("}")
        else:
            name, _, rest = line.partition(" ")
        fields = rest.split()
        if not fields:
            continue
        key = name.strip().rsplit(":", 1)[-1]  # drop a "vllm:" namespace
        if key.endswith("_total"):
            key = key[: -len("_total")]
        if key not in _SPEC_METRIC_KEYS:
            continue
        try:
            # fields[0] is the value; a trailing fields[1] would be a timestamp.
            value = float(fields[0])
        except ValueError:
            continue
        out[key] = out.get(key, 0.0) + value
    return out


def _try_metrics(base_url: str) -> dict:
    """Best-effort snapshot of vLLM Prometheus spec-decode counters.

    Fetches ``<base_url>/metrics`` and returns ``{bare_counter_name: value}`` for
    the counters in _SPEC_METRIC_KEYS that the server exposes. Returns an empty
    dict if the endpoint cannot be read, so a missing /metrics only costs the
    tau figure and never the benchmark run.
    """
    try:
        with urllib.request.urlopen(base_url.rstrip("/") + "/metrics", timeout=10) as r:
            text = r.read().decode("utf-8", errors="replace")
    except Exception:  # deliberate best-effort boundary: no /metrics => tau=None
        return {}
    return _parse_spec_metrics(text)
def generate_one(base_url: str, model: str, prompt: str, max_tokens: int,
                 temperature: float) -> dict:
    """One streamed /v1/completions request, timed. Returns timing + the text.

    Returned keys:
      ttft_s        seconds until the first non-empty text chunk arrived, or
                    None if the server streamed no text at all
      total_s       wall-clock seconds for the whole request
      new_tokens    completion tokens: the server's ``usage.completion_tokens``
                    when it sends a usage chunk, else the number of non-empty
                    text chunks (fallback only)
      tokens_per_s  decode rate after the first token,
                    (new_tokens - 1) / (total_s - ttft_s); 0.0 if < 2 tokens
      text          the concatenated completion text

    Why usage and not chunk counting: with speculative decoding the server can
    commit several tokens in one engine step and stream them as ONE chunk, so
    counting chunks undercounts tokens on exactly the runs that are fastest.

    Raises RuntimeError if the stream carries an error object; urllib's
    HTTP/URL errors propagate unchanged. A failed request never becomes a
    quietly empty measurement.
    """
    url = base_url.rstrip("/") + "/v1/completions"
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,  # 0.0 => greedy => deterministic => lossless-comparable
        "stream": True,
        # Ask for a final chunk carrying `usage`, so token counts are exact even
        # when one chunk holds several tokens (see docstring).
        "stream_options": {"include_usage": True},
    }
    req = urllib.request.Request(url, data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json"})
    t0 = time.perf_counter()
    ttft = None
    n_text_chunks = 0
    usage_tokens = None
    chunks = []
    with urllib.request.urlopen(req, timeout=600) as r:
        for raw in r:
            line = raw.decode("utf-8").strip()
            if not line.startswith("data:"):
                continue  # blank SSE separators and non-data fields
            body = line[len("data:"):].strip()
            if body == "[DONE]":
                break
            obj = json.loads(body)
            if obj.get("error") or obj.get("object") == "error":
                raise RuntimeError(f"{url} streamed an error: {body}")
            usage = obj.get("usage")
            if usage and usage.get("completion_tokens") is not None:
                usage_tokens = int(usage["completion_tokens"])
            # The final usage chunk has `choices: []`; don't index into it.
            choices = obj.get("choices") or []
            piece = (choices[0].get("text") or "") if choices else ""
            if piece:
                if ttft is None:
                    ttft = time.perf_counter() - t0
                n_text_chunks += 1
                chunks.append(piece)
    total = time.perf_counter() - t0
    # Prefer the server's exact count; chunk counting is only a fallback for
    # servers that ignore stream_options.
    n_tokens = usage_tokens if usage_tokens is not None else n_text_chunks
    decode_time = max(total - (ttft or 0.0), 1e-9)
    tps = (n_tokens - 1) / decode_time if n_tokens > 1 else 0.0
    return {
        "ttft_s": ttft,
        "total_s": total,
        "new_tokens": n_tokens,
        "tokens_per_s": tps,
        "text": "".join(chunks),
    }
def run_rollout_batch(base_url: str, model: str, prompts: list[str],
                      rollouts_per_example: int, max_tokens: int,
                      temperature: float, label: str) -> list[dict]:
    """Replay the prompt set rollouts_per_example times — one RL rollout batch.

    Requests are issued sequentially: the whole prompt list once, then again,
    rollouts_per_example times. Returns one generate_one() result per request,
    in issue order, and prints a progress line per rollout. A rollout that
    streamed no text (ttft_s is None) is shown as ``ttft=n/a`` instead of
    aborting the batch.
    """
    runs: list[dict] = []
    total = len(prompts) * rollouts_per_example
    for _ in range(rollouts_per_example):
        for prompt in prompts:
            res = generate_one(base_url, model, prompt, max_tokens, temperature)
            runs.append(res)
            ttft = res["ttft_s"]
            # Formatting None with :.3f raises TypeError, so guard it.
            ttft_txt = f"{ttft:.3f}s" if ttft is not None else "n/a"
            print(f" [{label}] rollout {len(runs)}/{total} "
                  f"tps={res['tokens_per_s']:.1f} ttft={ttft_txt}")
    return runs
def estimate_tau(before: dict, after: dict, gamma: int | None = None) -> float | None:
    """Acceptance length tau from two /metrics snapshots; None if not measurable.

    tau = tokens committed per target verification pass
        = (accepted draft tokens + 1 bonus token per pass) / passes
        = 1 + accepted / passes

    The number of passes comes from the ``spec_decode_num_drafts`` delta when
    both snapshots carry it; that value is exact and independent of the draft
    length. Otherwise passes are estimated as ``draft_tokens / gamma``. Here
    ``gamma`` defaults to NUM_SPECULATIVE_TOKENS, and the estimate is only
    right if the server proposes ``gamma`` tokens on every pass.

    Returns None in these cases:
      * A needed counter is missing from either snapshot. Diffing against an
        absent "before" would report the server's whole lifetime, not this run.
      * No drafting happened in the window.
      * A counter went backwards, e.g. the server restarted mid-run.
    """
    def delta(key: str) -> float | None:
        if key not in before or key not in after:
            return None
        return after[key] - before[key]

    accepted = delta("spec_decode_num_accepted_tokens")
    if accepted is None or accepted < 0:
        return None
    passes = delta("spec_decode_num_drafts")
    if passes is None:
        draft = delta("spec_decode_num_draft_tokens")
        if draft is None:
            return None
        k = NUM_SPECULATIVE_TOKENS if gamma is None else gamma
        if k <= 0:
            return None
        passes = draft / k
    if passes <= 0:
        return None
    committed = accepted + passes  # +1 bonus token per verification pass
    return committed / passes
def assert_parity(base_url: str, model: str, prompts: list[str], max_tokens: int) -> dict:
    """Run the GREEDY workload twice and require byte-identical completions.

    On correct speculative decoding, greedy output is invariant, so two runs
    MUST match. A mismatch means the served config is nondeterministic (or
    broken), not lossless.

    Returns a summary dict: parity_pairs, identical, mismatches,
    mismatch_indices (positions in prompt order), lossless.

    Raises AssertionError on any mismatch so a CI/demo run fails loudly. The
    raise is explicit rather than an ``assert`` statement, because ``python -O``
    strips asserts and would turn a failed check into a silent PASS.
    """
    print("[parity] greedy run A ...")
    run_a = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-A")
    print("[parity] greedy run B ...")
    run_b = run_rollout_batch(base_url, model, prompts, 1, max_tokens, 0.0, "parity-B")
    mismatch_indices = [i for i, (x, y) in enumerate(zip(run_a, run_b))
                        if x["text"] != y["text"]]
    mismatches = len(mismatch_indices)
    result = {
        "parity_pairs": len(run_a),
        "identical": len(run_a) - mismatches,
        "mismatches": mismatches,
        "mismatch_indices": mismatch_indices,
        "lossless": mismatches == 0,
    }
    print(json.dumps(result, indent=2))
    if mismatches:
        raise AssertionError(
            f"PARITY FAILED: {mismatches}/{len(run_a)} greedy completions differed across "
            f"two runs of the same endpoint — output is NOT deterministic/lossless."
        )
    print("[parity] PASS — greedy output is byte-identical across runs (lossless).")
    return result
def _mean_or_none(values) -> float | None:
    """Arithmetic mean of ``values``, or None if there are none.

    statistics.mean raises StatisticsError on empty input; a summary field must
    not crash the script after the whole batch has already run.
    """
    vals = list(values)
    return mean(vals) if vals else None


def _write_json(path: str, obj: dict) -> None:
    """Write ``obj`` to ``path`` as indented JSON, creating parent directories."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    print(f"[rollout_bench] wrote {path}")


def _project_cost(wall_s: float, total_tokens: int, tokens_per_s_aggregate: float,
                  hourly_rate: float, baseline_tps: float | None) -> dict:
    """Cost of this rollout batch and, given a baseline rate, projected savings.

    The batch cost is wall time x GPU price. If ``baseline_tps`` is positive and
    tokens were produced, it also projects what the SAME token count would have
    cost at the baseline rate. ``baseline_tps`` should be the baseline run's
    ``tokens_per_s_aggregate``, so both sides are measured the same way: total
    tokens over wall time, including TTFT.
    """
    batch_cost = (wall_s / 3600.0) * hourly_rate
    cost = {"hourly_rate": hourly_rate, "batch_cost_usd": batch_cost}
    if baseline_tps and baseline_tps > 0 and total_tokens > 0:
        baseline_wall_s = total_tokens / baseline_tps
        baseline_cost = (baseline_wall_s / 3600.0) * hourly_rate
        cost.update({
            "baseline_tps_reference": baseline_tps,
            "projected_baseline_wall_s": baseline_wall_s,
            "projected_baseline_cost_usd": baseline_cost,
            "projected_savings_usd": baseline_cost - batch_cost,
            "speedup_x": (tokens_per_s_aggregate / baseline_tps
                          if tokens_per_s_aggregate > 0 else None),
        })
    return cost


def main() -> None:
    """CLI entry point: run the parity check or the throughput batch, then report.

    With --assert-parity, it runs assert_parity() and optionally writes its
    result. Otherwise it:
      1. snapshots /metrics;
      2. runs one rollout batch;
      3. snapshots /metrics again;
      4. prints a JSON summary with throughput, TTFT, tau and an optional cost
         projection;
      5. optionally writes the summary plus per-rollout detail to --out.
    """
    p = argparse.ArgumentParser(
        description="Rollout-throughput benchmark (completions/sec, tokens/sec, TTFT, "
                    "acceptance length tau, projected $/run) against an OpenAI-compatible "
                    "endpoint. Measures live; never fabricates.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--base-url", default="http://localhost:8000",
                   help="OpenAI-compatible endpoint root (vLLM serves /v1 and /metrics under it).")
    p.add_argument("--model", default="laguna",
                   help="Served model name/id (serve_vllm.py registers the alias 'laguna').")
    p.add_argument("--label", default="dflash",
                   help="Tag for the output: baseline | dflash. Just labels the JSON.")
    p.add_argument("--prompts", type=int, default=len(PROMPTS),
                   help="How many of the built-in prompts to use (1..%d)." % len(PROMPTS))
    p.add_argument("--rollouts-per-example", type=int, default=8,
                   help="Completions sampled per prompt — mirrors the RL config's group size.")
    p.add_argument("--max-tokens", type=int, default=512,
                   help="Max new tokens per completion. Match the RL sampling cap for honest $/run.")
    p.add_argument("--temperature", type=float, default=0.0,
                   help="0.0 = greedy/deterministic (the lossless-comparable workload). "
                        "Keep 0 for the A/B so baseline and DFlash do identical work.")
    p.add_argument("--hourly-rate", type=float, default=None,
                   help="GPU $/hour. If set, projects rollout-batch cost and (with --baseline-tps) savings.")
    p.add_argument("--baseline-tps", type=float, default=None,
                   help="The tokens_per_s_aggregate value from a prior --label baseline run's JSON. "
                        "Lets this run project the $ SAVED vs baseline for the same rollout workload.")
    p.add_argument("--assert-parity", action="store_true",
                   help="Run the greedy workload twice and assert byte-identical output (lossless check). "
                        "Exits nonzero on mismatch. Skips the throughput batch.")
    p.add_argument("--out", default=None, help="Write JSON summary here (e.g. results/rollout_dflash.json).")
    args = p.parse_args()
    # Reject degenerate workloads up front instead of crashing after the batch.
    if args.rollouts_per_example < 1:
        p.error("--rollouts-per-example must be >= 1")
    if args.max_tokens < 1:
        p.error("--max-tokens must be >= 1")
    if args.hourly_rate is not None and args.hourly_rate < 0:
        p.error("--hourly-rate must be >= 0")
    prompts = PROMPTS[:max(1, min(args.prompts, len(PROMPTS)))]

    if args.assert_parity:
        result = assert_parity(args.base_url, args.model, prompts, args.max_tokens)
        if args.out:
            _write_json(args.out, {"label": args.label, "parity": result})
        return

    before = _try_metrics(args.base_url)
    t_start = time.perf_counter()
    runs = run_rollout_batch(args.base_url, args.model, prompts,
                             args.rollouts_per_example, args.max_tokens,
                             args.temperature, args.label)
    wall_s = time.perf_counter() - t_start
    after = _try_metrics(args.base_url)
    tau = estimate_tau(before, after)

    total_tokens = sum(r["new_tokens"] for r in runs)
    n_rollouts = len(runs)
    completions_per_s = n_rollouts / wall_s if wall_s > 0 else 0.0
    tokens_per_s_aggregate = total_tokens / wall_s if wall_s > 0 else 0.0
    summary = {
        "label": args.label,
        "model": args.model,
        "base_url": args.base_url,
        "prompts": len(prompts),
        "rollouts_per_example": args.rollouts_per_example,
        "n_rollouts": n_rollouts,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "wall_s": wall_s,
        "completions_per_s": completions_per_s,  # rollout throughput — the headline
        "total_new_tokens": total_tokens,
        "tokens_per_s_aggregate": tokens_per_s_aggregate,
        "tokens_per_s_mean_per_rollout": _mean_or_none(r["tokens_per_s"] for r in runs),
        # None if no rollout streamed any text.
        "ttft_s_mean": _mean_or_none(r["ttft_s"] for r in runs if r["ttft_s"] is not None),
        "acceptance_length_tau": tau,  # None if /metrics absent — read it off /metrics by hand then
        "spec_metrics_before": before,
        "spec_metrics_after": after,
    }
    # ---- projected $/run: the dollars-and-cents form of the thesis ----------
    if args.hourly_rate is not None:
        summary["cost_projection"] = _project_cost(
            wall_s, total_tokens, tokens_per_s_aggregate,
            args.hourly_rate, args.baseline_tps,
        )
    print(json.dumps(summary, indent=2))
    if args.out:
        # Persist per-rollout detail alongside the summary for later inspection.
        _write_json(args.out, {**summary, "runs": runs})


if __name__ == "__main__":
    main()