Update 01_benchmark.py
Browse files- 01_benchmark.py +129 -0
01_benchmark.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""SimpleTool multi-head parallel decode — vLLM, v1/v2, external prompts.

Usage:
    python 01_benchmark.py --version v2              # v2 default model
    python 01_benchmark.py --version v1              # v1 default model
    python 01_benchmark.py --version v2 --n-args 3   # fixed three arg heads
    python 01_benchmark.py --version v1 --model /my/model  # custom model path
"""
import argparse, json, time, os
from pathlib import Path

# Directory holding scenarios.json, the per-scenario tool files, and v1_system.txt.
DIR = Path("./prompts")
# One (head name, opening tag, closing tag) triple per decode head:
# the function head plus arg1..arg6.
HEADS = [("function","<function>","</function>")] + [(f"arg{i}",f"<arg{i}>",f"</arg{i}>") for i in range(1,7)]
# Stop strings handed to vLLM SamplingParams; any one terminates a head's decode.
STOPS = ["</function>"] + [f"</arg{i}>" for i in range(1,7)] + ["</content>","<|null|>","<|im_end|>"]
# Default model path per prompt-format version (overridable via --model).
MODELS = {"v1":"./models/RT-Qwen3-4B-AWQ", "v2":"./models/RT-Qwen3-4B-AWQ-v2"}
def load_scenarios():
    """Load scenarios.json and attach each scenario's tool definitions.

    Reads ``DIR / "scenarios.json"``, then for every scenario loads the file
    named by its ``tools_file`` key and stores the stripped contents under
    the ``tools`` key. Returns the mutated list of scenario dicts.
    """
    scenarios = json.loads((DIR / "scenarios.json").read_text())
    for scenario in scenarios:
        tools_path = DIR / scenario["tools_file"]
        scenario["tools"] = tools_path.read_text().strip()
    return scenarios
def max_tool_params(tools_str):
    """Return the largest parameter count among newline-delimited tool specs.

    Args:
        tools_str: One JSON tool definition per line; each valid line is
            expected to contain ``function.parameters.properties``.

    Returns:
        The maximum number of properties over all parseable lines
        (0 when no line parses).
    """
    max_params = 0
    for line in tools_str.strip().split("\n"):
        try:
            props = json.loads(line)["function"]["parameters"]["properties"]
        # Skip malformed/incomplete lines only — the original bare `except:`
        # also swallowed KeyboardInterrupt and SystemExit.
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
        max_params = max(max_params, len(props))
    return max_params
def build_prompt(sc, ver):
    """Assemble the chat-format prompt for scenario *sc* under version *ver*.

    v1 puts the shared system file plus the tools into the system turn and
    moves the scenario's own system text into the user turn; v2 uses the
    scenario's system text directly in the system turn. Both end with an
    opened assistant turn ready for decoding.
    """
    tools = sc["tools"]
    if ver == "v1":
        system_text = (DIR/"v1_system.txt").read_text()
        parts = [
            f"<|im_start|>system\n{system_text}\n## Available Tools:\n\n{tools}<|im_end|>\n",
            f"<|im_start|>user\nenvironment: []\nhistory: {sc['history']}\n\n{sc['system']}\n\n{sc['query']}<|im_end|>\n",
            "<|im_start|>assistant\n",
        ]
    else:
        parts = [
            f"<|im_start|>system\n{sc['system']}\n\n## Available Tools:\n\n{tools}<|im_end|>\n",
            f"<|im_start|>user\nhistory: {sc['history']}\n\n{sc['query']}<|im_end|>\n",
            "<|im_start|>assistant\n",
        ]
    return "".join(parts)
def clean(t):
    """Normalize one decoded head output.

    Whitespace-only output, or output containing the null sentinel, collapses
    to ``"<|null|>"``. Otherwise everything from the first closing-tag marker
    (``"</"``) onward is discarded and the remainder is stripped.
    """
    text = t.strip()
    if not text or "<|null|>" in text:
        return "<|null|>"
    body, _, _ = text.partition("</")
    return body.strip()
def main():
    """CLI entry point: load scenarios, start vLLM, and benchmark multi-head
    parallel decode — cold start, 3-round hot warmup, accuracy test, summary
    table, and a fully-dumped example for the first scenario."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default=None)
    ap.add_argument("--version", default="v2", choices=["v1","v2"])
    ap.add_argument("--n-args", default="auto")  # "auto" or a fixed count (clamped to 1..6)
    ap.add_argument("--gpu", type=int, default=0)
    ap.add_argument("--max-model-len", type=int, default=4096)
    a = ap.parse_args()
    a.model = a.model or MODELS[a.version]
    # Pin the GPU *before* importing vLLM so device selection takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(a.gpu)
    from vllm import LLM, SamplingParams

    SC = load_scenarios()
    print(f"\n{'='*60}\n {a.version} | {a.model}\n{'='*60}")
    llm = LLM(model=a.model, trust_remote_code=True, dtype="auto", gpu_memory_utilization=0.80,
              max_model_len=a.max_model_len, max_num_seqs=8, enable_prefix_caching=True)
    sp = SamplingParams(temperature=0.0, max_tokens=128, stop=STOPS, include_stop_str_in_output=True)
    # Arg-heads per scenario: "auto" = the scenario's max tool parameter count
    # (capped at 6); otherwise the user-supplied count clamped to 1..6.
    na = [min(max_tool_params(s["tools"]),6) if a.n_args=="auto" else max(1,min(6,int(a.n_args))) for s in SC]
    for s,n in zip(SC,na): print(f" {s['name']:<35} heads={1+n}")

    def run(sc, n):
        # One generate() call covering every head: shared prefix + each head's
        # opening tag; prefix caching lets the shared part be computed once.
        hd = HEADS[:1+n]; base = build_prompt(sc, a.version)
        t0 = time.perf_counter()
        outs = llm.generate([base+op for _,op,_ in hd], sp)
        ms = (time.perf_counter()-t0)*1000
        raw, toks, full = {}, {}, {}
        for j,(nm,_,_) in enumerate(hd):
            if j<len(outs) and outs[j].outputs:
                o = outs[j].outputs[0]; full[nm]=o.text; raw[nm]=clean(o.text); toks[nm]=len(o.token_ids)
            else: raw[nm],toks[nm],full[nm] = "<|null|>",0,""
        return raw, toks, full, ms, hd

    # Cold: first run per scenario, nothing cached yet.
    print(f"\n{'='*60}\n COLD START\n{'='*60}")
    cold = []
    for i,s in enumerate(SC): _,_,_,ms,_=run(s,na[i]); cold.append(ms); print(f" {s['name']:<35} {ms:7.1f}ms")

    # Hot x3: repeat rounds so the prefix cache is warm.
    print(f"\n{'='*60}\n HOT WARMUP (3 rounds)\n{'='*60}")
    hot = [[] for _ in SC]
    for r in range(3):
        for i,s in enumerate(SC): _,_,_,ms,_=run(s,na[i]); hot[i].append(ms)
        print(f" Round {r+1}: "+" ".join(f"{hot[j][-1]:6.1f}ms" for j in range(len(SC))))

    # Test: measured run + accuracy check against each scenario's expected tool.
    print(f"\n{'='*60}\n PARALLEL TEST ({a.version})\n{'='*60}\n")
    res = []
    for i,s in enumerate(SC):
        raw,toks,full,ms,hd = run(s,na[i]); mt=max(toks.values()) if toks else 0
        # Pass iff the function head exactly matches the expected tool name.
        ok = raw.get("function","") == s["expected"]; res.append((s,raw,toks,full,ms,mt,hd,ok))
        print(f"βββ {s['name']} βββ\n{'PASS' if ok else 'FAIL'} {s['desc']}")
        for nm,_,_ in hd:
            v,tc = raw.get(nm,""),toks.get(nm,0); d=v if len(v)<=43 else v[:43]+"β¦"
            st = ("OK" if ok else f"WRONG({v})") if nm=="function" else ("NULL" if v=="<|null|>" else "FILL")
            print(f" {nm:<10} {d:<45} {tc:<4} {st}")
        print(f" e2e={ms:.1f}ms max_tok={mt}\n")

    # Summary: accuracy plus averaged latencies across all scenarios.
    N=len(res); np_=sum(r[7] for r in res); ae=sum(r[4] for r in res)/N; amt=sum(r[5] for r in res)/N
    print(f"{'='*60}\n SUMMARY ({a.version})\n{'='*60}")
    print(f" Accuracy : {np_}/{N}\n Cold start avg : {sum(cold)/N:.1f}ms\n Hot prefill avg: {sum(sum(h) for h in hot)/sum(len(h) for h in hot):.1f}ms")
    print(f" E2E avg (hot) : {ae:.1f}ms\n Max head tokens: {amt:.1f} avg\n E2E / max_tok : {ae/amt:.1f}ms/tok (decode bottleneck)\n")
    print(f" {'Scenario':<35} {'Cold':>7} {'Hot':>7} {'E2E':>7} {'MaxTk':>6} {'ms/tk':>6}\n {'β'*70}")
    for i,(s,_,_,_,ms,mt,_,_) in enumerate(res):
        print(f" {s['name']:<35} {cold[i]:6.1f} {sum(hot[i])/3:6.1f} {ms:6.1f} {mt:>5} {ms/mt if mt else 0:5.1f}")

    # Example dump: full prompt, per-head triggers, raw decode, and the
    # reconstructed multi-head response for the first scenario.
    s,raw,toks,full,ms,mt,hd,ok = res[0]; base=build_prompt(s,a.version)
    print(f"\n{'='*60}\n EXAMPLE ({a.version}): {s['name']}\n{'='*60}")
    print(f"\nββ Shared Prefix ({len(base)} chars) ββββββββββββββββββββ")
    for ln in base.split("\n"): print(f"β {ln}")
    print(f"βββββββββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"\nββ Per-Head Trigger Tokens βββββββββββββββββββββββββ")
    for nm,op,_ in hd: print(f"β {nm:<10} β {op}")
    print(f"βββββββββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"\nββ Decode Output (all tokens, incl. stop) ββββββββββ")
    for nm,op,_ in hd: print(f"β {nm:<10} [{toks.get(nm,0):>2} tok] {op}{full.get(nm,'')}")
    print(f"βββββββββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"\n Reconstructed multi-head response:")
    for nm,op,cl in hd:
        if raw.get(nm,"")=="<|null|>": print(f" {op}<|null|>")
        else:
            # Append the closing tag only when decode did not already emit a stop string.
            ft=full.get(nm,""); print(f" {op}{ft}" if any(ft.rstrip().endswith(x) for x in STOPS) else f" {op}{ft}{cl}")
    print()

if __name__ == "__main__": main()