#!/usr/bin/env python3
# qwen3-coder-next / bench.py
# gdubicki, commit d8bbe98 (verified): "Add bench.py performance benchmark script"
"""
bench.py — performance benchmarks for the vLLM server running Qwen3-Coder-Next-NVFP4-GB10.
Measures:
- Time to first token (TTFT) via streaming
- Decode throughput (tok/s)
- Prefill throughput (prompt tok/s)
- Latency across prompt lengths: short / medium / long / ~2K-token context
- Concurrent request throughput (1, 2, 4, 8, 16 parallel requests)
- Reasoning ON vs OFF overhead
- Tool-calling latency (non-streaming request offering a get_weather tool)
Usage:
python3 bench.py
python3 bench.py --host 192.168.1.50
python3 bench.py --host localhost --port 8000 --runs 3
"""
import argparse
import json
import statistics
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Optional
import urllib.request
import urllib.error
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Command-line options. The module docstring doubles as the --help description.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--host", default="localhost",
                    help="vLLM server host (default: localhost)")
parser.add_argument("--port", type=int, default=8000,
                    help="vLLM server port (default: 8000)")
parser.add_argument("--runs", type=int, default=3, help="Runs per scenario (default: 3)")
parser.add_argument("--no-color", action="store_true",
                    help="disable ANSI colours (also disabled when stdout is not a TTY)")
args = parser.parse_args()
# A non-positive run count would make every scenario report "all runs failed"
# and skip the tool-calling median, which reads like a server problem.
if args.runs < 1:
    parser.error("--runs must be >= 1")
BASE_URL = f"http://{args.host}:{args.port}/v1"
# Emit plain text when asked to, or when output is piped/redirected, so logs
# are not littered with escape codes.
if args.no_color or not sys.stdout.isatty():
    GREEN = RED = YELLOW = CYAN = BOLD = NC = ""
else:
    GREEN = "\033[0;32m"
    RED = "\033[0;31m"
    YELLOW = "\033[0;33m"
    CYAN = "\033[0;36m"
    BOLD = "\033[1m"
    NC = "\033[0m"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def get_model_id() -> str:
    """Return the id of the first model listed by ``GET {BASE_URL}/models``.

    Raises:
        urllib.error.URLError: if the server cannot be reached or returns an HTTP error.
        RuntimeError: if the server answers but lists no models.
    """
    req = urllib.request.Request(f"{BASE_URL}/models")
    with urllib.request.urlopen(req, timeout=10) as r:
        data = json.loads(r.read())
    models = data.get("data") or []
    if not models:
        # Without this check an empty list surfaces as a bare "list index out of range".
        raise RuntimeError(f"server at {BASE_URL} lists no models")
    return models[0]["id"]
def chat_stream(model: str, messages: list, max_tokens: int, enable_thinking: bool,
                timeout: float = 300) -> tuple[float, float, int, int]:
    """
    Send one streaming chat-completion request and time it.

    TTFT is the time until the first streamed delta that carries answer text or
    reasoning text. Total is the time until the stream ends. Token counts come
    from the ``usage`` chunk requested via ``stream_options.include_usage``; they
    stay 0 if the server never sends one.

    Args:
        model: model id to request.
        messages: OpenAI-style chat messages.
        max_tokens: completion token cap.
        enable_thinking: forwarded as ``chat_template_kwargs.enable_thinking``.
        timeout: socket timeout in seconds for the HTTP request.

    Returns: (ttft_s, total_s, prompt_tokens, completion_tokens)

    Raises:
        urllib.error.URLError: on connection or HTTP errors.
        RuntimeError: if the server reports an error inside the stream.
    """
    payload = json.dumps({
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0.1,
        "stream": True,
        "stream_options": {"include_usage": True},
        "chat_template_kwargs": {"enable_thinking": enable_thinking},
    }).encode()
    req = urllib.request.Request(
        f"{BASE_URL}/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    ttft = None
    t0 = time.perf_counter()
    prompt_tokens = 0
    completion_tokens = 0
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        # Server-sent events: one "data: <json>" line per chunk, then "data: [DONE]".
        for raw_line in resp:
            line = raw_line.decode("utf-8").strip()
            if not line.startswith("data:"):
                continue
            chunk = line[5:].strip()
            if chunk == "[DONE]":
                break
            try:
                obj = json.loads(chunk)
            except json.JSONDecodeError:
                continue
            # A mid-stream failure would otherwise be scored as a 0-token "success".
            if "error" in obj or obj.get("object") == "error":
                detail = obj.get("error") or obj.get("message") or obj
                raise RuntimeError(f"server error in stream: {detail}")
            # First token
            if ttft is None:
                choices = obj.get("choices") or []
                if choices:
                    delta = choices[0].get("delta") or {}
                    # Reasoning text counts as the first token too. Some vLLM
                    # versions name this field "reasoning" rather than
                    # "reasoning_content", so accept both.
                    content = (delta.get("content")
                               or delta.get("reasoning_content")
                               or delta.get("reasoning"))
                    if content:
                        ttft = time.perf_counter() - t0
            # Usage (normally only on the final chunk)
            usage = obj.get("usage")
            if usage:
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)
    total = time.perf_counter() - t0
    # No text was streamed at all: report TTFT as the whole request time.
    if ttft is None:
        ttft = total
    return ttft, total, prompt_tokens, completion_tokens
@dataclass
class Result:
    """Measurements for one benchmark scenario, accumulated over its runs.

    Each list field gets one sample per successful run. The two token counts
    are plain integers that the caller overwrites, so they describe the most
    recently recorded run.
    """

    name: str
    ttft_ms: list[float] = field(default_factory=list)      # time to first token, ms
    decode_tps: list[float] = field(default_factory=list)   # decode throughput, tok/s
    prefill_tps: list[float] = field(default_factory=list)  # prompt tok/s, estimated from TTFT
    total_s: list[float] = field(default_factory=list)      # end-to-end request time, s
    prompt_tokens: int = 0
    completion_tokens: int = 0
def run_scenario(name: str, model: str, messages: list, max_tokens: int,
                 enable_thinking: bool = False, runs: int = 3) -> Result:
    """Run one streaming scenario ``runs`` times and collect per-run metrics.

    Prints one progress mark per run (green dot = success, red cross = failure).
    If any run failed, it also prints the last error message.

    Decode throughput follows the usual TPOT convention. The first token's
    latency is already part of TTFT, so decode tok/s is
    ``(completion_tokens - 1) / (total - ttft)``. Prefill throughput is
    estimated as ``prompt_tokens / ttft``. Either metric is recorded as 0 when
    its inputs are too small to be meaningful.

    Returns:
        A Result with one sample per successful run. Its token counts come from
        the last successful run.
    """
    res = Result(name=name)
    last_error = None
    print(f" {CYAN}{name}{NC}", end="", flush=True)
    for _ in range(runs):
        try:
            ttft, total, pt, ct = chat_stream(model, messages, max_tokens, enable_thinking)
        except Exception as e:  # best-effort: one failed run must not abort the benchmark
            last_error = e
            print(f" {RED}✗{NC}", end="", flush=True)
            continue
        decode_time = total - ttft
        decode_tokens = ct - 1  # first token is accounted for by TTFT
        res.ttft_ms.append(ttft * 1000)
        res.decode_tps.append(
            decode_tokens / decode_time if decode_time > 0.01 and decode_tokens > 0 else 0
        )
        res.prefill_tps.append(pt / ttft if ttft > 0.01 else 0)
        res.total_s.append(total)
        res.prompt_tokens = pt
        res.completion_tokens = ct
        print(f" {GREEN}·{NC}", end="", flush=True)
    print()
    if last_error is not None:
        print(f"    {RED}last error: {last_error}{NC}")
    return res
def print_result(res: Result):
    """Print a scenario's token counts and median metrics, or a failure notice if no run succeeded."""
    if not res.ttft_ms:
        print(f" {RED}all runs failed{NC}")
        return
    median = statistics.median
    # Build every line (computing all medians) before printing any of them.
    report = (
        f" prompt tokens : {res.prompt_tokens}",
        f" completion tok : {res.completion_tokens}",
        f" TTFT (median) : {BOLD}{median(res.ttft_ms):.0f} ms{NC}",
        f" decode (median) : {BOLD}{median(res.decode_tps):.1f} tok/s{NC}",
        f" prefill (median): {median(res.prefill_tps):.0f} tok/s",
        f" total (median) : {median(res.total_s):.1f} s",
    )
    for text in report:
        print(text)
def run_concurrent(model: str, messages: list, max_tokens: int, concurrency: int,
                   enable_thinking: bool = False) -> tuple[float, float]:
    """Fire `concurrency` requests simultaneously, return (wall_s, aggregate_tps).

    wall_s runs from starting the first worker thread until the last one
    finishes. aggregate_tps is the total completion tokens of the successful
    requests divided by wall_s. Failed requests add no tokens but still count
    toward wall_s. When any request fails, a warning with the failure count and
    the first error is printed, so a low figure is not mistaken for a slow server.
    """
    results: list[Optional[tuple[float, int]]] = [None] * concurrency
    errors: list[Optional[Exception]] = [None] * concurrency

    def worker(idx: int) -> None:
        # Each thread writes only its own slot, so the lists need no lock.
        try:
            _, total, _, ct = chat_stream(model, messages, max_tokens, enable_thinking)
            results[idx] = (total, ct)
        except Exception as e:
            errors[idx] = e

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(concurrency)]
    t0 = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    wall = time.perf_counter() - t0
    total_tokens = sum(r[1] for r in results if r is not None)
    agg_tps = total_tokens / wall if wall > 0 else 0
    failures = [e for e in errors if e is not None]
    if failures:
        print(f"  {YELLOW}warning: {len(failures)}/{concurrency} requests failed "
              f"(first error: {failures[0]}){NC}")
    return wall, agg_tps
# ---------------------------------------------------------------------------
# Prompts
# ---------------------------------------------------------------------------
# Prompt fixtures for the latency-vs-prompt-length section. The "~N prompt tok"
# labels used in main() are rough estimates; the measured count is the
# server-reported usage printed with each result.
SHORT_PROMPT = "Write a Python one-liner that reverses a string."
MEDIUM_PROMPT = (
"Write a Python class implementing a generic LRU cache with O(1) get and put. "
"Use OrderedDict. Include docstrings, type annotations, and a short usage example."
)
# Code-review prompt. The embedded Flask snippet is deliberately insecure
# (user input passed to a shell=True command, hard-coded DB password,
# os.environ echoed back to the client, debug server bound to 0.0.0.0) so the
# model has real issues to find. It is only sent as text, never executed.
LONG_PROMPT = (
"You are a senior software engineer. Review the following Python code and provide "
"a detailed analysis covering: correctness, edge cases, performance, readability, "
"and security concerns. Suggest concrete improvements with code examples.\n\n"
"```python\n"
+ "\n".join([
"import subprocess, os, json",
"from flask import Flask, request",
"",
"app = Flask(__name__)",
"",
"def run_query(user_input):",
" result = subprocess.run(",
" f'mysql -u root -ppassword mydb -e \"{user_input}\"',",
" shell=True, capture_output=True, text=True",
" )",
" return result.stdout",
"",
"@app.route('/query')",
"def query():",
" data = request.args.get('q', '')",
" output = run_query(data)",
" return json.dumps({'result': output, 'debug': os.environ})",
"",
"if __name__ == '__main__':",
" app.run(debug=True, host='0.0.0.0')",
])
+ "\n```"
)
# Long-context prompt, intended to be roughly 2K tokens: one architecture
# paragraph (prefixed "Context: ") repeated 12 times, followed by four
# open-ended questions about it.
CONTEXT_PROMPT = (
"You are given the following context about a distributed system architecture. "
"After reading it carefully, answer the questions at the end.\n\n"
+ ("Context: " + "A microservices-based e-commerce platform consists of the following services: "
"UserService (authentication, profiles), ProductService (catalog, search), "
"OrderService (cart, checkout, order management), PaymentService (Stripe integration), "
"NotificationService (email/SMS), and AnalyticsService (event tracking). "
"All services communicate via gRPC internally and expose REST APIs externally. "
"A Redis cluster handles session data and caching. PostgreSQL with read replicas "
"serves as the primary database. Kafka handles async event streaming between services. "
"Kubernetes on AWS EKS manages deployment with HPA for auto-scaling. "
"A global CDN sits in front of the API gateway. ") * 12
+ "\n\nQuestions:\n"
"1. What are the main single points of failure in this architecture?\n"
"2. How would you handle a PaymentService outage gracefully?\n"
"3. What observability stack would you recommend and why?\n"
"4. Suggest a strategy for zero-downtime database migrations.\n"
)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _bench_prompt_lengths(model: str, runs: int) -> list[Result]:
    """Section 1: streaming latency/throughput for short, medium, long and ~2K-token prompts."""
    print(f"{BOLD}1. Latency across prompt lengths (reasoning OFF){NC}")
    print(f" ({runs} runs each, streaming, median reported)\n")
    results = []
    for name, prompt, max_tok in [
        ("short (~10 prompt tok, 200 output)", SHORT_PROMPT, 200),
        ("medium (~80 prompt tok, 500 output)", MEDIUM_PROMPT, 500),
        ("long (~400 prompt tok, 800 output)", LONG_PROMPT, 800),
        ("ctx (~2K prompt tok, 600 output)", CONTEXT_PROMPT, 600),
    ]:
        messages = [{"role": "user", "content": prompt}]
        res = run_scenario(name, model, messages, max_tok, enable_thinking=False, runs=runs)
        print_result(res)
        results.append(res)
        print()
    return results


def _bench_reasoning(model: str, runs: int) -> None:
    """Section 2: the medium prompt with reasoning disabled vs enabled (800 output tokens)."""
    print(f"{BOLD}2. Reasoning ON vs OFF (medium prompt, 800 output tokens){NC}\n")
    messages = [{"role": "user", "content": MEDIUM_PROMPT}]
    for label, thinking in [("reasoning OFF", False), ("reasoning ON ", True)]:
        res = run_scenario(label, model, messages, 800, enable_thinking=thinking, runs=runs)
        print_result(res)
        print()


def _bench_concurrency(model: str) -> None:
    """Section 3: aggregate throughput with 1-16 simultaneous short-prompt requests."""
    print(f"{BOLD}3. Concurrent requests throughput (short prompt, 300 output tok){NC}\n")
    messages = [{"role": "user", "content": SHORT_PROMPT}]
    print(f" {'concurrency':<14} {'wall_s':>8} {'agg tok/s':>12}")
    print(f" {'-'*36}")
    for c in [1, 2, 4, 8, 16]:
        wall, agg = run_concurrent(model, messages, 300, c, enable_thinking=False)
        print(f" {c:<14} {wall:>8.1f} {agg:>12.1f}")
    print()


def _tool_call_payload(model: str) -> bytes:
    """Build the non-streaming request body for the get_weather tool-calling probe."""
    return json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": "What is the weather in Warsaw? Use the get_weather tool."}],
        "max_tokens": 200,
        "temperature": 0.1,
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                }
            }
        }],
        "tool_choice": "auto",
        "chat_template_kwargs": {"enable_thinking": False},
    }).encode()


def _bench_tool_calling(model: str, runs: int) -> None:
    """Section 4: latency of a non-streaming request that should produce a tool call.

    Only runs whose reply contains ``tool_calls`` count toward the median. A
    yellow ``?`` marks a reply without a tool call; a red ``✗`` marks a failed request.
    """
    print(f"{BOLD}4. Tool calling latency{NC}\n")
    payload = _tool_call_payload(model)
    times = []
    print(" tool_call latency", end="", flush=True)
    for _ in range(runs):
        try:
            req = urllib.request.Request(
                f"{BASE_URL}/chat/completions",
                data=payload,
                headers={"Content-Type": "application/json"},
                method="POST",
            )
            t0 = time.perf_counter()
            with urllib.request.urlopen(req, timeout=60) as resp:
                data = json.loads(resp.read())
            elapsed = time.perf_counter() - t0
            tool_calls = data["choices"][0]["message"].get("tool_calls")
            if tool_calls:
                times.append(elapsed * 1000)
                print(f" {GREEN}·{NC}", end="", flush=True)
            else:
                print(f" {YELLOW}?{NC}", end="", flush=True)
        except Exception:  # best-effort: mark the failure and keep going
            print(f" {RED}✗{NC}", end="", flush=True)
    print()
    if times:
        print(f" latency (median): {BOLD}{statistics.median(times):.0f} ms{NC}")
    else:
        # Otherwise the section ends silently and looks like it was skipped.
        print(f"    {YELLOW}no run returned a tool call{NC}")
    print()


def _print_summary(results: list[Result]) -> None:
    """Print median TTFT and decode tok/s for each scenario with at least one successful run."""
    print(f"{BOLD}{'='*60}{NC}")
    print(f"{BOLD} Summary{NC}")
    print(f"{BOLD}{'='*60}{NC}")
    print(f" {'scenario':<40} {'TTFT ms':>8} {'tok/s':>8}")
    print(f" {'-'*58}")
    for res in results:
        if res.ttft_ms:
            print(f" {res.name:<40} {statistics.median(res.ttft_ms):>8.0f} {statistics.median(res.decode_tps):>8.1f}")
    print()


def main():
    """Check the server is reachable, run every benchmark section, then print a summary.

    Exits with status 1 if the model list cannot be fetched. The summary covers
    only the prompt-length scenarios from section 1.
    """
    print(f"\n{BOLD}{'='*60}{NC}")
    print(f"{BOLD} vLLM Performance Benchmark{NC}")
    print(f"{BOLD} {BASE_URL}{NC}")
    print(f"{BOLD}{'='*60}{NC}\n")
    # Check server
    try:
        model = get_model_id()
        print(f"{GREEN}[OK]{NC} Server up. Model: {model}\n")
    except Exception as e:
        print(f"{RED}[FAIL]{NC} Cannot reach server: {e}")
        sys.exit(1)
    runs = args.runs
    results = _bench_prompt_lengths(model, runs)
    _bench_reasoning(model, runs)
    _bench_concurrency(model)
    _bench_tool_calling(model, runs)
    _print_summary(results)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script. Note that argument
    # parsing still happens at import time, at module level.
    main()