from __future__ import annotations

import argparse
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Optional

import requests
import torch
from transformers import AutoTokenizer
from model import load_and_process_dataset

from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    find_available_port,
    popen_launch_server,
)


def _is_blackwell() -> bool:
    if envs.IS_BLACKWELL.get():
        return True
    return get_device_sm() >= 100
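
# NOTE: get_device_sm() reports the CUDA compute capability as an integer
# (e.g. 90 for Hopper, 100+ for Blackwell-class GPUs), which is why SM >= 100
# is treated as Blackwell here; the envs.IS_BLACKWELL flag can force this
# regardless of the detected compute capability.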


def _flush_cache(base_url: str) -> None:
    resp = requests.get(base_url + "/flush_cache", timeout=60)
    resp.raise_for_status()


def _send_generate(
    base_url: str,
    prompt: str,
    *,
    max_new_tokens: int,
    stop: list[str],
    timeout_s: int,
) -> dict:
    sampling_params: dict = {
        "temperature": 0.0,
        "top_p": 1.0,
        "top_k": 1,
        "max_new_tokens": int(max_new_tokens),
    }
    if stop:
        sampling_params["stop"] = stop
    resp = requests.post(
        base_url + "/generate",
        json={
            "text": prompt,
            "sampling_params": sampling_params,
        },
        timeout=int(timeout_s),
    )
    resp.raise_for_status()
    return resp.json()
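
# For reference, a /generate response is expected to look roughly like the
# sketch below (only the fields this script reads are shown; the speculative
# fields appear when a speculative algorithm such as DFLASH is enabled):
#
#   {
#       "text": "...",
#       "meta_info": {
#           "completion_tokens": 123,
#           "spec_verify_ct": 17,
#           "spec_accept_length": 3.4,
#       },
#   }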


def _send_generate_batch(
    base_url: str,
    prompts: list[str],
    *,
    max_new_tokens: int,
    stop: list[str],
    timeout_s: int,
) -> list[dict]:
    if not prompts:
        return []
    sampling_params: dict = {
        "temperature": 0.0,
        "top_p": 1.0,
        "top_k": 1,
        "max_new_tokens": int(max_new_tokens),
    }
    if stop:
        sampling_params["stop"] = stop
    resp = requests.post(
        base_url + "/generate",
        json={
            "text": prompts,
            "sampling_params": sampling_params,
        },
        timeout=int(timeout_s),
    )
    resp.raise_for_status()
    out = resp.json()
    if not isinstance(out, list):
        raise RuntimeError(
            "Expected a list response for batched /generate, but got "
            f"type={type(out).__name__}."
        )
    return out


@dataclass(frozen=True)
class BenchMetrics:
    latency_s: float
    output_tokens: int
    output_toks_per_s: float
    spec_accept_length: Optional[float]
    spec_verify_ct_sum: int
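
# output_toks_per_s is aggregate output-token throughput: the sum of
# completion_tokens across all timed requests divided by the wall-clock time
# of the timed portion of the run (the warmup batch is excluded).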


def _run_bench_requests(
    base_url: str,
    *,
    prompts: list[str],
    max_new_tokens: int,
    concurrency: int,
    batch_requests: bool,
    stop: list[str],
    timeout_s: int,
    expect_dflash: bool,
) -> BenchMetrics:
    # Run one warmup batch (excluded from metrics), then time the remaining
    # prompts either as server-side batches or as client-side concurrent
    # requests.
    bs = max(int(concurrency), 1)
    if len(prompts) > bs:
        warmup_prompts = prompts[:bs]
        if batch_requests:
            _send_generate_batch(
                base_url,
                warmup_prompts,
                max_new_tokens=max_new_tokens,
                stop=stop,
                timeout_s=timeout_s,
            )
        else:
            with ThreadPoolExecutor(max_workers=int(concurrency)) as pool:
                futures = [
                    pool.submit(
                        _send_generate,
                        base_url,
                        prompt,
                        max_new_tokens=max_new_tokens,
                        stop=stop,
                        timeout_s=timeout_s,
                    )
                    for prompt in warmup_prompts
                ]
                for fut in as_completed(futures):
                    fut.result()
        prompts = prompts[bs:]

    start = time.perf_counter()
    total_tokens = 0
    spec_verify_ct_sum = 0
    spec_accept_lengths: list[float] = []

    if batch_requests:
        for start_idx in range(0, len(prompts), bs):
            chunk_prompts = prompts[start_idx : start_idx + bs]
            outs = _send_generate_batch(
                base_url,
                chunk_prompts,
                max_new_tokens=max_new_tokens,
                stop=stop,
                timeout_s=timeout_s,
            )
            if len(outs) != len(chunk_prompts):
                raise RuntimeError(
                    "Batched /generate output length mismatch: "
                    f"got {len(outs)} outputs for {len(chunk_prompts)} prompts."
                )

            for out in outs:
                meta = out.get("meta_info", {}) or {}
                total_tokens += int(meta.get("completion_tokens", 0))
                spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0))
                if "spec_accept_length" in meta:
                    try:
                        spec_accept_lengths.append(float(meta["spec_accept_length"]))
                    except (TypeError, ValueError):
                        pass
    else:
        with ThreadPoolExecutor(max_workers=int(concurrency)) as pool:
            futures = {
                pool.submit(
                    _send_generate,
                    base_url,
                    prompt,
                    max_new_tokens=max_new_tokens,
                    stop=stop,
                    timeout_s=timeout_s,
                ): i
                for i, prompt in enumerate(prompts)
            }
            for fut in as_completed(futures):
                out = fut.result()
                meta = out.get("meta_info", {}) or {}
                total_tokens += int(meta.get("completion_tokens", 0))
                spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0))
                if "spec_accept_length" in meta:
                    try:
                        spec_accept_lengths.append(float(meta["spec_accept_length"]))
                    except (TypeError, ValueError):
                        pass

    latency = time.perf_counter() - start
    toks_per_s = total_tokens / max(latency, 1e-6)

    if expect_dflash and spec_verify_ct_sum <= 0:
        raise RuntimeError(
            "DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses "
            "(DFLASH may not have been enabled)."
        )

    spec_accept_length = (
        float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None
    )

    return BenchMetrics(
        latency_s=float(latency),
        output_tokens=int(total_tokens),
        output_toks_per_s=float(toks_per_s),
        spec_accept_length=spec_accept_length,
        spec_verify_ct_sum=int(spec_verify_ct_sum),
    )
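
# A minimal standalone usage sketch (assuming a server is already listening on
# http://127.0.0.1:30000; the URL and prompts below are illustrative only):
#
#   metrics = _run_bench_requests(
#       "http://127.0.0.1:30000",
#       prompts=["Q1", "Q2", "Q3", "Q4", "Q5"],
#       max_new_tokens=64,
#       concurrency=4,
#       batch_requests=False,
#       stop=[],
#       timeout_s=600,
#       expect_dflash=False,
#   )
#   print(metrics.output_toks_per_s)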


def _format_table(
    *,
    concurrencies: list[int],
    values: dict[int, Optional[float]],
    float_fmt: str,
) -> str:
    header = ["conc"] + [str(c) for c in concurrencies]
    lines = [
        "| " + " | ".join(header) + " |",
        "| " + " | ".join(["---"] * len(header)) + " |",
    ]
    row = ["value"]
    for c in concurrencies:
        v = values.get(c, None)
        row.append("N/A" if v is None else format(v, float_fmt))
    lines.append("| " + " | ".join(row) + " |")
    return "\n".join(lines)
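
# With concurrencies=[1, 2, 4], values={1: 100.0, 2: 190.0} and
# float_fmt=",.2f", _format_table renders:
#
#   | conc | 1 | 2 | 4 |
#   | --- | --- | --- | --- |
#   | value | 100.00 | 190.00 | N/A |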


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output-md",
        type=str,
        default=None,
        help="Write a markdown report to this file (disabled by default).",
    )
    parser.add_argument("--dataset-name", type=str, default="gsm8k")
    parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B")
    parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16")
    parser.add_argument(
        "--skip-baseline",
        action="store_true",
        help="Skip the baseline (target-only) sweep; only run DFLASH and report N/A "
        "for baseline and speedup.",
    )
    parser.add_argument(
        "--batch-requests",
        action="store_true",
        help="Send prompts as server-side batched /generate requests (batch size = "
        "concurrency) instead of client-side concurrent requests.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--timeout-s", type=int, default=3600)
    parser.add_argument("--mem-fraction-static", type=float, default=0.75)
    parser.add_argument("--disable-radix-cache", action="store_true")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--max-running-requests", type=int, default=64)
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (single value, no sweep).",
    )
    parser.add_argument(
        "--concurrencies",
        type=str,
        default="1,2,4,8,16,32",
        help="Comma-separated list of client concurrency levels.",
    )
    parser.add_argument(
        "--questions-per-concurrency-base",
        type=int,
        default=128,
        help="num_questions = base * concurrency (default matches the sweep plan).",
    )
    parser.add_argument(
        "--max-questions-per-config",
        type=int,
        default=1024,
        help="Cap num_questions per (tp, concurrency) run (default: 1024).",
    )
    parser.add_argument(
        "--attention-backends",
        type=str,
        default="flashinfer,fa3,fa4",
        help="Comma-separated list. fa3 is auto-skipped unless SM90 (Hopper), "
        "and fa4 unless SM100+ (Blackwell).",
    )
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this sweep.")

    concurrencies = [int(x) for x in args.concurrencies.split(",") if x.strip()]
    concurrencies = [c for c in concurrencies if c >= 1]
    if not concurrencies:
        raise RuntimeError("No concurrencies specified.")

    num_questions_by_conc = {
        c: min(
            int(args.questions_per_concurrency_base) * int(c),
            int(args.max_questions_per_config),
        )
        for c in concurrencies
    }
    max_questions = max(num_questions_by_conc.values())
    max_concurrency = max(concurrencies)
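
    # Example: with the defaults (base=128, cap=1024) a concurrency sweep of
    # 1, 2, 4, 8, 16, 32 uses 128, 256, 512, 1024, 1024, 1024 questions
    # respectively.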

    attention_backends = [s.strip() for s in args.attention_backends.split(",") if s.strip()]
    is_blackwell = _is_blackwell()
    device_sm = get_device_sm()
    if device_sm != 90:
        attention_backends = [b for b in attention_backends if b != "fa3"]
    if device_sm < 100:
        attention_backends = [b for b in attention_backends if b != "fa4"]
    attention_backends = attention_backends or ["flashinfer"]

    print(f"Loading dataset: {args.dataset_name}...")
    dataset = load_and_process_dataset(args.dataset_name)
    required_questions = max_questions + max_concurrency
    if len(dataset) < required_questions:
        print(
            f"Warning: Dataset has {len(dataset)} items, but need up to "
            f"{required_questions}. Reusing items."
        )

    tokenizer = AutoTokenizer.from_pretrained(args.target_model)

    prompts: list[str] = []
    for i in range(required_questions):
        item = dataset[i % len(dataset)]
        user_content = item["turns"][0]
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_content}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompts.append(prompt_text)
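
    # NOTE: enable_thinking=False is forwarded to apply_chat_template as an
    # extra template variable; Qwen3's chat template uses it to disable
    # thinking-mode output, and templates that do not reference it should
    # simply ignore the kwarg.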

    baseline_toks: dict[tuple[str, int], Optional[float]] = {}
    dflash_toks: dict[tuple[str, int], Optional[float]] = {}
    dflash_accept_len: dict[tuple[str, int], Optional[float]] = {}

    tp = args.tp_size

    for backend in attention_backends:
        port_base = find_available_port(20000)

        common_server_args: list[str] = [
            "--trust-remote-code",
            "--attention-backend",
            backend,
            "--tp-size",
            str(tp),
            "--dtype",
            str(args.dtype),
            "--mem-fraction-static",
            str(args.mem_fraction_static),
            "--max-running-requests",
            str(args.max_running_requests),
        ]
        # Capture CUDA graphs for every batch size from 1 to 32.
        common_server_args.extend(
            ["--cuda-graph-bs", *[str(i) for i in range(1, 33)], "--cuda-graph-max-bs", "32"]
        )
        if args.disable_radix_cache:
            common_server_args.append("--disable-radix-cache")
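
        # For reference, the flags above roughly correspond to a manual launch
        # along the lines of the sketch below (illustrative only;
        # popen_launch_server assembles the real command and port):
        #
        #   python -m sglang.launch_server --model-path Qwen/Qwen3-8B \
        #       --trust-remote-code --attention-backend flashinfer --tp-size 1 \
        #       --dtype bfloat16 --mem-fraction-static 0.75 --max-running-requests 64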

        if not args.skip_baseline:
            print(f"\n=== backend={backend} tp={tp} (baseline) ===")
            baseline_port = port_base
            baseline_url = f"http://127.0.0.1:{baseline_port}"
            baseline_proc = popen_launch_server(
                args.target_model,
                baseline_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=common_server_args,
            )
            try:
                # Warm up / sanity-check the server with a tiny request.
                _send_generate(
                    baseline_url,
                    "Hello",
                    max_new_tokens=8,
                    stop=[],
                    timeout_s=min(int(args.timeout_s), 300),
                )

                for conc in concurrencies:
                    n = num_questions_by_conc[conc]
                    _flush_cache(baseline_url)
                    print(
                        f"[warmup] running 1 warmup batch (size={conc}) after "
                        "/flush_cache; excluded from metrics."
                    )
                    metrics = _run_bench_requests(
                        baseline_url,
                        prompts=prompts[: n + conc],
                        max_new_tokens=int(args.max_new_tokens),
                        concurrency=int(conc),
                        batch_requests=bool(args.batch_requests),
                        stop=[],
                        timeout_s=int(args.timeout_s),
                        expect_dflash=False,
                    )
                    baseline_toks[(backend, conc)] = metrics.output_toks_per_s
                    print(
                        f"[baseline] conc={conc:>2} n={n:<4} "
                        f"toks/s={metrics.output_toks_per_s:,.2f} "
                        f"latency={metrics.latency_s:.1f}s"
                    )
            finally:
                kill_process_tree(baseline_proc.pid)
                try:
                    baseline_proc.wait(timeout=30)
                except Exception:
                    pass

        print(f"\n=== backend={backend} tp={tp} (DFLASH) ===")
        dflash_port = find_available_port(port_base + 1)
        dflash_url = f"http://127.0.0.1:{dflash_port}"
        dflash_proc = popen_launch_server(
            args.target_model,
            dflash_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                *common_server_args,
                "--speculative-algorithm",
                "DFLASH",
                "--speculative-draft-model-path",
                args.draft_model,
            ],
        )
        try:
            # Warm up / sanity-check the server with a tiny request.
            _send_generate(
                dflash_url,
                "Hello",
                max_new_tokens=8,
                stop=[],
                timeout_s=min(int(args.timeout_s), 300),
            )
            for conc in concurrencies:
                n = num_questions_by_conc[conc]
                _flush_cache(dflash_url)
                print(
                    f"[warmup] running 1 warmup batch (size={conc}) after "
                    "/flush_cache; excluded from metrics."
                )
                metrics = _run_bench_requests(
                    dflash_url,
                    prompts=prompts[: n + conc],
                    max_new_tokens=int(args.max_new_tokens),
                    concurrency=int(conc),
                    batch_requests=bool(args.batch_requests),
                    stop=[],
                    timeout_s=int(args.timeout_s),
                    expect_dflash=True,
                )
                dflash_toks[(backend, conc)] = metrics.output_toks_per_s
                dflash_accept_len[(backend, conc)] = metrics.spec_accept_length
                # Guard against a missing accept length (None) before formatting.
                accept_len_str = (
                    f"{metrics.spec_accept_length:.3f}"
                    if metrics.spec_accept_length is not None
                    else "N/A"
                )
                print(
                    f"[DFLASH] conc={conc:>2} n={n:<4} "
                    f"toks/s={metrics.output_toks_per_s:,.2f} "
                    f"latency={metrics.latency_s:.1f}s "
                    f"accept_len={accept_len_str} "
                    f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}"
                )
        finally:
            kill_process_tree(dflash_proc.pid)
            try:
                dflash_proc.wait(timeout=30)
            except Exception:
                pass

    md_lines: list[str] = []
    md_lines.append("# DFLASH Bench Report")
    md_lines.append("")
    md_lines.append("## Settings")
    md_lines.append(f"- dataset: `{args.dataset_name}`")
    md_lines.append(f"- target_model: `{args.target_model}`")
    md_lines.append(f"- draft_model: `{args.draft_model}`")
    md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`")
    md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`")
    md_lines.append(f"- tp_size: `{tp}`")
    md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`")
    md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`")
    md_lines.append(f"- device_sm: `{device_sm}`")
    md_lines.append(f"- is_blackwell: `{is_blackwell}`")
    md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`")
    md_lines.append("- drop_first_batch: `true`")
    md_lines.append("")

    for backend in attention_backends:
        md_lines.append(f"## Backend: `{backend}`")
        md_lines.append("")

        baseline_values = {
            c: baseline_toks.get((backend, c), None) for c in concurrencies
        }
        dflash_values = {
            c: dflash_toks.get((backend, c), None) for c in concurrencies
        }
        speedup_values: dict[int, Optional[float]] = {}
        for c in concurrencies:
            b = baseline_values.get(c, None)
            d = dflash_values.get(c, None)
            speedup_values[c] = None if (b is None or d is None or b <= 0) else (d / b)

        md_lines.append("### Baseline output tok/s")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=baseline_values,
                float_fmt=",.2f",
            )
        )
        md_lines.append("")

        md_lines.append("### DFLASH output tok/s")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=dflash_values,
                float_fmt=",.2f",
            )
        )
        md_lines.append("")

        md_lines.append("### Speedup (DFLASH / baseline)")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values=speedup_values,
                float_fmt=".3f",
            )
        )
        md_lines.append("")

        md_lines.append("### DFLASH acceptance length")
        md_lines.append(
            _format_table(
                concurrencies=concurrencies,
                values={
                    c: dflash_accept_len.get((backend, c), None)
                    for c in concurrencies
                },
                float_fmt=".3f",
            )
        )
        md_lines.append("")

    if args.output_md:
        with open(args.output_md, "w", encoding="utf-8") as f:
            f.write("\n".join(md_lines))
            f.write("\n")
        print(f"\nWrote markdown report to: {args.output_md}")
    else:
        print("\nMarkdown report disabled (pass --output-md to write one).")


if __name__ == "__main__":
    main()
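
# Example invocation (the script filename is a placeholder; all flags are
# defined in the argument parser above):
#
#   python bench_dflash.py --concurrencies 1,2,4,8 --max-new-tokens 1024 \
#       --attention-backends flashinfer --output-md dflash_report.md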