# bitnet-ai / stress_test.py
# Soumik-404's picture
# first commit
# 3a74e13
# Raw
# History Blame Contribute Delete
# 10.7 kB
"""
BitNet b1.58 async stress test.
Fires N concurrent requests against the llama-server OpenAI-compatible API
and collects per-request timing, token throughput, and error stats.
Uses asyncio + aiohttp for non-blocking I/O.
Usage:
python stress_test.py [--url URL] [--requests N] [--concurrency C] [--timeout SECONDS] [--warmup W]
"""
import sys
import time
import json
import asyncio
import statistics
import argparse
from datetime import datetime
import aiohttp
import colorama
# Initialise colorama with autoreset enabled (see the colorama docs for the
# exact reset semantics).
colorama.init(autoreset=True)

# Short aliases used by the logging/report code in this file.
C, S = colorama.Fore, colorama.Style
DIM = S.DIM
# Request body sent unchanged with every request of the run.
PAYLOAD = dict(
    model="bitnet",
    messages=[
        dict(
            role="user",
            content=" ".join(
                (
                    "Explain the concept of neural networks in simple terms.",
                    "Include how they are trained and what makes them different",
                    "from traditional computer programs.",
                )
            ),
        ),
    ],
    max_tokens=80,
    temperature=0.7,
)

# Human-readable label for PAYLOAD, echoed in the audit report.
DESCRIPTION = "Medium chat completion (32 tok prompt, 80 tok generate)"
def log(msg, color=C.WHITE, bright=False):
    """Print ``msg`` to stdout behind a dimmed ``HH:MM:SS.mmm`` timestamp.

    Args:
        msg: Text to print.
        color: Colour prefix placed before the message.
        bright: When true, the message is additionally prefixed with the
            bright style.
    """
    now = datetime.now()
    # "%f" yields microseconds; dropping the last three digits leaves ms.
    stamp = now.strftime("%H:%M:%S.%f")[:-3]
    if bright:
        weight = S.BRIGHT
    else:
        weight = ""
    print(f"{DIM}{stamp}{S.RESET_ALL} {weight}{color}{msg}{S.RESET_ALL}")
# ---------------------------------------------------------------------------
# Per-request worker
# ---------------------------------------------------------------------------
async def do_request(session, req_id, url, timeout):
    """Send one chat-completion request and summarise its outcome.

    Args:
        session: Client session; ``session.post(url, json=..., timeout=...)``
            must return an async context manager yielding a response that
            exposes ``status`` and an awaitable ``json()``.
        req_id: Integer used to label the log line and the result.
        url: Endpoint that ``PAYLOAD`` is POSTed to.
        timeout: Total per-request timeout in seconds.

    Returns:
        A dict with the keys ``id``, ``ok``, ``status`` (0 when no response
        was received), ``elapsed`` (seconds), ``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``, ``tokens_per_sec``,
        ``content_len`` and ``error`` (``None`` on success).

    Failures derived from ``Exception`` (including timeouts) are not raised;
    they are logged and reported through ``ok``/``error`` in the result.
    """
    start = time.perf_counter()
    # Pessimistic defaults: only a fully parsed 200 response overrides them.
    status = 0
    ok = False
    err = None
    pt = ct = tt = 0
    tok_sec = 0
    content = ""
    try:
        async with session.post(
            url,
            json=PAYLOAD,
            timeout=aiohttp.ClientTimeout(total=timeout),
        ) as resp:
            status = resp.status
            body = await resp.json()
            # Stop the clock only after the body has been read so the latency
            # covers the complete response, not just the arrival of headers.
            elapsed = time.perf_counter() - start

        if status == 200 and isinstance(body, dict) and body.get("choices"):
            usage = body.get("usage") or {}
            pt = usage.get("prompt_tokens", 0)
            ct = usage.get("completion_tokens", 0)
            tt = usage.get("total_tokens", 0)
            tok_sec = ct / elapsed if elapsed > 0 else 0
            message = body["choices"][0].get("message") or {}
            # A null "content" is treated as an empty completion.
            content = message.get("content") or ""
            ctext = content[:55].replace("\n", " ")
            log(
                f"[#{req_id:03d}] {C.GREEN}OK{C.RESET} "
                f"{elapsed:6.1f}s {ct:3d} tok {tok_sec:5.1f} tok/s "
                f"\"{ctext}...\"",
            )
            # Flip to success last, so any parsing error above leaves ok False.
            ok = True
        else:
            error_info = body.get("error") if isinstance(body, dict) else None
            if isinstance(error_info, dict):
                err = error_info.get("message", f"HTTP {status}")
            elif error_info:
                # Error payload that is not an object (e.g. a bare string).
                err = str(error_info)
            else:
                err = f"HTTP {status}"
            log(f"[#{req_id:03d}] {C.RED}FAIL{C.RESET} {elapsed:6.1f}s {err}", C.RED)
    except asyncio.TimeoutError:
        elapsed = time.perf_counter() - start
        ok = False
        err = "TIMEOUT"
        log(f"[#{req_id:03d}] {C.RED}TIMEOUT{C.RESET} {elapsed:6.1f}s", C.RED)
    except Exception as e:
        # Worker boundary: one failed request must not abort the whole run,
        # so the failure is recorded in the result instead of propagating.
        elapsed = time.perf_counter() - start
        ok = False
        # Some exceptions stringify to ""; fall back to the class name so the
        # error breakdown still has a meaningful label.
        err = (str(e) or type(e).__name__)[:70]
        log(f"[#{req_id:03d}] {C.RED}ERROR{C.RESET} {elapsed:6.1f}s {err}", C.RED)

    if not ok:
        # Discard partially parsed numbers from a response that failed midway.
        pt = ct = tt = 0
        tok_sec = 0
        content = ""

    return {
        "id": req_id,
        "ok": ok,
        "status": status,
        "elapsed": elapsed,
        "prompt_tokens": pt,
        "completion_tokens": ct,
        "total_tokens": tt,
        "tokens_per_sec": tok_sec,
        "content_len": len(content),
        "error": err,
    }
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def print_audit(results, wall_time, url, n_req, concurrency, warmup):
    """Print the audit report for a finished stress-test run.

    Args:
        results: Per-request result dicts. Keys read here: ``ok``,
            ``elapsed``, ``tokens_per_sec``, ``completion_tokens``,
            ``prompt_tokens`` and ``error``.
        wall_time: Wall-clock duration of the measured phase, in seconds.
        url: Endpoint that was tested (echoed in the report).
        n_req: Number of measured requests (echoed in the report).
        concurrency: Configured concurrency (echoed in the report).
        warmup: Number of warmup requests (echoed in the report).

    Returns:
        None. The report is written to stdout through ``log`` and ``print``.
    """
    ok_r = [r for r in results if r["ok"]]
    fail_r = [r for r in results if not r["ok"]]
    elapsed_ok = [r["elapsed"] for r in ok_r]
    tok_sec_ok = [r["tokens_per_sec"] for r in ok_r]
    ct_ok = [r["completion_tokens"] for r in ok_r]
    pt_ok = [r["prompt_tokens"] for r in ok_r]
    total_prompt = sum(pt_ok)
    total_completion = sum(ct_ok)
    # int(n * 0.95) is always < n for n >= 1, so the index stays in range.
    p95 = sorted(elapsed_ok)[int(len(elapsed_ok) * 0.95)] if elapsed_ok else 0
    # Label for the configured client-side concurrency.
    concurrency_note = f"{concurrency} concurrent (asyncio)"
    # Guard against division by zero when there are no results at all.
    n_results = max(len(results), 1)

    print()
    log("=" * 68, C.MAGENTA, bright=True)
    log(" STRESS TEST AUDIT REPORT", C.MAGENTA, bright=True)
    log("=" * 68, C.MAGENTA, bright=True)
    print()
    rows = [
        ("Target URL", url),
        ("Description", DESCRIPTION),
        ("Total requests", str(n_req)),
        ("Concurrency", concurrency_note),
        ("Warmup requests", str(warmup)),
        ("Wall clock", f"{wall_time:.1f}s"),
    ]
    for k, v in rows:
        log(f" {C.YELLOW}{k:<22}{S.RESET_ALL} {v}")
    print()

    log(f" {C.CYAN}{'RESULTS':<22}{S.RESET_ALL}", bright=True)
    log(f" {'Succeeded':<22} {C.GREEN}{len(ok_r):>4}{S.RESET_ALL} ({len(ok_r)/n_results*100:5.1f}%)", C.GREEN)
    log(f" {'Failed':<22} {C.RED}{len(fail_r):>4}{S.RESET_ALL} ({len(fail_r)/n_results*100:5.1f}%)", C.RED)
    print()

    log(f" {C.CYAN}{'LATENCY (seconds)':<22}{S.RESET_ALL}", bright=True)
    if elapsed_ok:
        log(f" {'Average':<22} {statistics.mean(elapsed_ok):>6.2f}")
        log(f" {'Median':<22} {statistics.median(elapsed_ok):>6.2f}")
        log(f" {'Min':<22} {C.GREEN}{min(elapsed_ok):>6.2f}{S.RESET_ALL}")
        log(f" {'Max':<22} {C.RED}{max(elapsed_ok):>6.2f}{S.RESET_ALL}")
        log(f" {'P95':<22} {p95:>6.2f}")
    print()

    log(f" {C.CYAN}{'TOKEN THROUGHPUT (tok/s)':<22}{S.RESET_ALL}", bright=True)
    if tok_sec_ok:
        log(f" {'Average':<22} {statistics.mean(tok_sec_ok):>6.1f}")
        log(f" {'Median':<22} {statistics.median(tok_sec_ok):>6.1f}")
        log(f" {'Min':<22} {C.RED}{min(tok_sec_ok):>6.1f}{S.RESET_ALL}")
        log(f" {'Max':<22} {C.GREEN}{max(tok_sec_ok):>6.1f}{S.RESET_ALL}")
    print()

    log(f" {C.CYAN}{'TOKEN COUNTS':<22}{S.RESET_ALL}", bright=True)
    log(f" {'Total prompt tokens':<22} {total_prompt:>6}")
    log(f" {'Total completion tok':<22} {total_completion:>6}")
    log(f" {'Total combined':<22} {total_prompt + total_completion:>6}")
    if ct_ok:
        log(f" {'Avg completion/req':<22} {statistics.mean(ct_ok):>6.1f}")
    print()

    log(f" {C.CYAN}{'SYSTEM THROUGHPUT':<22}{S.RESET_ALL}", bright=True)
    # Aggregate figures are relative to wall-clock time, not per-request time.
    agg_tok_sec = total_completion / wall_time if wall_time > 0 else 0
    req_per_min = len(ok_r) / (wall_time / 60) if wall_time > 0 else 0
    log(f" {'Completion tok/s':<22} {C.GREEN}{agg_tok_sec:>6.1f}{S.RESET_ALL}")
    log(f" {'Requests/min':<22} {req_per_min:>6.1f}")
    print()

    if fail_r:
        log(f" {C.RED}{'ERROR BREAKDOWN':<22}{S.RESET_ALL}", bright=True)
        counts = {}
        for r in fail_r:
            key = r["error"] or "UNKNOWN"
            counts[key] = counts.get(key, 0) + 1
        # Most frequent error first.
        for err, cnt in sorted(counts.items(), key=lambda x: -x[1]):
            log(f" {err:<30} {C.RED}{cnt}{S.RESET_ALL}")

    # ------------------------------------------------------------------
    # Findings & recommendations
    # ------------------------------------------------------------------
    print()
    log("=" * 68, C.MAGENTA, bright=True)
    log(" ANALYSIS", C.MAGENTA, bright=True)
    log("=" * 68, C.MAGENTA, bright=True)
    findings = []
    if elapsed_ok:
        single = min(elapsed_ok)
        tail = max(elapsed_ok)
        if tail > single * 1.5:
            findings.append(
                f"{C.YELLOW}• High tail latency: max {tail:.1f}s vs min {single:.1f}s. "
                f"Requests queue behind each other.{S.RESET_ALL}"
            )
    if fail_r:
        findings.append(
            f"{C.RED}{len(fail_r)}/{len(results)} requests failed — "
            f"check server health and resource limits.{S.RESET_ALL}"
        )
    # Static note: this text is hard-coded and not derived from the results.
    findings.append(
        f"{C.GREEN}• Server runs with 2 CPU cores, 2 parallel slots "
        f"(--parallel 2), continuous batching enabled.{S.RESET_ALL}"
    )
    # Use the locally computed aggregate throughput. Looking it up through
    # globals() can never find a local variable, which silently disabled
    # this finding. A throughput of 0 still produces no finding.
    if agg_tok_sec:
        if agg_tok_sec < 5:
            findings.append(
                f"{C.YELLOW}• System throughput ({agg_tok_sec:.1f} tok/s) is low — "
                f"limited by 2-core CPU constraint.{S.RESET_ALL}"
            )
        else:
            findings.append(
                f"{C.GREEN}• System throughput ({agg_tok_sec:.1f} tok/s) "
                f"is acceptable for CPU inference.{S.RESET_ALL}"
            )
    for f_text in findings:
        log(f" {f_text}")
    print()
    log("=" * 68, C.MAGENTA, bright=True)
    print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def main():
    """Parse CLI options, run the warmup and measured phases, print the report.

    Warmup requests are sent one at a time and their results are discarded.
    The measured requests are then all scheduled at once; the connector's
    ``limit`` is what bounds how many connections are open simultaneously.

    Exits via ``parser.error`` (usage message, exit status 2) when a count
    option is negative.
    """
    parser = argparse.ArgumentParser(
        description="BitNet b1.58 async stress test",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--url", default="http://localhost:8080/v1/chat/completions")
    parser.add_argument("--requests", type=int, default=50)
    parser.add_argument("--concurrency", type=int, default=10)
    parser.add_argument("--timeout", type=int, default=120)
    parser.add_argument("--warmup", type=int, default=0,
                        help="Number of warmup requests (results discarded)")
    args = parser.parse_args()

    # Negative counts used to be accepted and silently produced a wrong
    # warmup/measured split (negative slicing); reject them up front.
    # 0 is still accepted and passed through unchanged.
    for option in ("requests", "concurrency", "warmup"):
        if getattr(args, option) < 0:
            parser.error(f"--{option} must be >= 0")

    print()
    log("=" * 68, C.CYAN, bright=True)
    log(" BitNet b1.58 Async Stress Test", C.CYAN, bright=True)
    log(f" {args.requests} requests | {args.concurrency} concurrent "
        f"| {args.warmup} warmup | {args.url}", C.CYAN)
    log("=" * 68, C.CYAN, bright=True)
    print()

    # Request ids are 1-based: warmup requests take the first ids and the
    # measured requests continue the numbering after them.
    warmup_ids = range(1, args.warmup + 1)
    measured_ids = range(args.warmup + 1, args.warmup + args.requests + 1)

    connector = aiohttp.TCPConnector(limit=args.concurrency, force_close=True)
    async with aiohttp.ClientSession(connector=connector) as session:
        # Warmup phase: sequential, results intentionally dropped.
        for req_id in warmup_ids:
            await do_request(session, req_id, args.url, args.timeout)

        # Main test: the wall clock covers only the measured requests.
        wall_start = time.perf_counter()
        tasks = [
            asyncio.create_task(
                do_request(session, req_id, args.url, args.timeout)
            )
            for req_id in measured_ids
        ]
        results = await asyncio.gather(*tasks)
        wall_time = time.perf_counter() - wall_start

    print_audit(results, wall_time, args.url, args.requests,
                args.concurrency, args.warmup)
# Script entry point: start the event loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())