Cialtion committed on
Commit
420ec60
Β·
verified Β·
1 Parent(s): 3218ed9

Update 01_benchmark.py

Browse files
Files changed (1) hide show
  1. 01_benchmark.py +129 -0
01_benchmark.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""SimpleTool multi-head parallel decode — vLLM, v1/v2, external prompts.

Usage:
    python 01_benchmark.py --version v2                    # v2 default model
    python 01_benchmark.py --version v1                    # v1 default model
    python 01_benchmark.py --version v2 --n-args 3         # fixed three arg heads
    python 01_benchmark.py --version v1 --model /my/model  # custom model path
"""
8
+ import argparse, json, time, os
9
+ from pathlib import Path
10
+
11
# Directory holding scenarios.json, the per-scenario tool-catalog files, and v1_system.txt.
DIR = Path("./prompts")
# Decode heads as (name, open_tag, close_tag): one function-name head plus six argument heads.
HEADS = [("function","<function>","</function>")] + [(f"arg{i}",f"<arg{i}>",f"</arg{i}>") for i in range(1,7)]
# Stop strings handed to SamplingParams; decoding of any head halts at the first one emitted.
STOPS = ["</function>"] + [f"</arg{i}>" for i in range(1,7)] + ["</content>","<|null|>","<|im_end|>"]
# Default model path per prompt-format version (overridable with --model).
MODELS = {"v1":"./models/RT-Qwen3-4B-AWQ", "v2":"./models/RT-Qwen3-4B-AWQ-v2"}
15
+
16
def load_scenarios():
    """Read scenarios.json from DIR and attach each scenario's tool catalog.

    Each scenario dict names its catalog via "tools_file"; the file's
    stripped text is stored under the "tools" key before returning.
    """
    scenarios = json.loads((DIR / "scenarios.json").read_text())
    for scenario in scenarios:
        catalog_path = DIR / scenario["tools_file"]
        scenario["tools"] = catalog_path.read_text().strip()
    return scenarios
22
def max_tool_params(tools_str):
    """Return the largest parameter count among the tools in a JSONL catalog.

    Each line of *tools_str* is expected to be a JSON object shaped like
    {"function": {"parameters": {"properties": {...}}}}.  Blank or malformed
    lines are skipped.  Returns 0 when no line yields a usable tool.
    """
    most = 0
    for line in tools_str.strip().split("\n"):
        try:
            props = json.loads(line)["function"]["parameters"]["properties"]
        # Narrowed from a bare `except: pass`, which also swallowed
        # KeyboardInterrupt/SystemExit; these are the only errors the
        # parse-and-index expression can raise for bad input.
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
        most = max(most, len(props))
    return most
28
+
29
+ def build_prompt(sc, ver):
30
+ t = sc["tools"]
31
+ if ver == "v1":
32
+ v1sys = (DIR/"v1_system.txt").read_text()
33
+ return (f"<|im_start|>system\n{v1sys}\n## Available Tools:\n\n{t}<|im_end|>\n"
34
+ f"<|im_start|>user\nenvironment: []\nhistory: {sc['history']}\n\n{sc['system']}\n\n{sc['query']}<|im_end|>\n"
35
+ f"<|im_start|>assistant\n")
36
+ return (f"<|im_start|>system\n{sc['system']}\n\n## Available Tools:\n\n{t}<|im_end|>\n"
37
+ f"<|im_start|>user\nhistory: {sc['history']}\n\n{sc['query']}<|im_end|>\n"
38
+ f"<|im_start|>assistant\n")
39
+
40
def clean(t):
    """Normalize one decoded head's text.

    Empty output or anything containing the explicit null token collapses to
    "<|null|>"; otherwise everything from the first closing tag ("</") on is
    dropped and the remainder is stripped.
    """
    stripped = t.strip()
    if "<|null|>" in stripped or not stripped:
        return "<|null|>"
    value, _sep, _tail = stripped.partition("</")
    return value.strip()
43
+
44
def main():
    """Run the multi-head parallel-decode benchmark.

    Phases: parse args -> load scenarios -> cold-start timing -> three hot
    warmup rounds -> accuracy test with per-head detail -> summary table ->
    example dump of the first scenario's prompt, triggers, and outputs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default=None)  # explicit model path; falls back to MODELS[--version]
    ap.add_argument("--version", default="v2", choices=["v1","v2"])
    ap.add_argument("--n-args", default="auto")  # "auto" = per-scenario head count from the tool catalog
    ap.add_argument("--gpu", type=int, default=0)
    ap.add_argument("--max-model-len", type=int, default=4096)
    a = ap.parse_args()
    a.model = a.model or MODELS[a.version]
    # Must be set before vLLM touches CUDA — hence the deliberately late import below.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(a.gpu)
    from vllm import LLM, SamplingParams

    SC = load_scenarios()
    print(f"\n{'='*60}\n {a.version} | {a.model}\n{'='*60}")
    # enable_prefix_caching lets the shared prompt prefix be reused across the per-head requests.
    llm = LLM(model=a.model, trust_remote_code=True, dtype="auto", gpu_memory_utilization=0.80,
              max_model_len=a.max_model_len, max_num_seqs=8, enable_prefix_caching=True)
    # Greedy decode; stop strings stay in the output so each head's text can be shown verbatim.
    sp = SamplingParams(temperature=0.0, max_tokens=128, stop=STOPS, include_stop_str_in_output=True)
    # Argument-head count per scenario: auto = widest tool's parameter count (capped at 6),
    # otherwise the fixed --n-args value clamped to 1..6.
    na = [min(max_tool_params(s["tools"]),6) if a.n_args=="auto" else max(1,min(6,int(a.n_args))) for s in SC]
    for s,n in zip(SC,na): print(f" {s['name']:<35} heads={1+n}")

    def run(sc, n):
        """Decode 1+n heads for one scenario in a single batched generate call.

        Returns (raw, toks, full, ms, hd): cleaned text, token counts, and
        raw text per head name, wall-clock milliseconds, and the head list.
        """
        hd = HEADS[:1+n]; base = build_prompt(sc, a.version)
        t0 = time.perf_counter()
        # One request per head: shared prefix + that head's opening tag.
        outs = llm.generate([base+op for _,op,_ in hd], sp)
        ms = (time.perf_counter()-t0)*1000
        raw, toks, full = {}, {}, {}
        for j,(nm,_,_) in enumerate(hd):
            if j<len(outs) and outs[j].outputs:
                o = outs[j].outputs[0]; full[nm]=o.text; raw[nm]=clean(o.text); toks[nm]=len(o.token_ids)
            else: raw[nm],toks[nm],full[nm] = "<|null|>",0,""  # missing output treated as a null head
        return raw, toks, full, ms, hd

    # Cold: first run per scenario, nothing cached yet.
    print(f"\n{'='*60}\n COLD START\n{'='*60}")
    cold = []
    for i,s in enumerate(SC): _,_,_,ms,_=run(s,na[i]); cold.append(ms); print(f" {s['name']:<35} {ms:7.1f}ms")

    # Hot x3: repeat every scenario three times with caches warm.
    print(f"\n{'='*60}\n HOT WARMUP (3 rounds)\n{'='*60}")
    hot = [[] for _ in SC]
    for r in range(3):
        for i,s in enumerate(SC): _,_,_,ms,_=run(s,na[i]); hot[i].append(ms)
        print(f" Round {r+1}: "+" ".join(f"{hot[j][-1]:6.1f}ms" for j in range(len(SC))))

    # Test: accuracy means the "function" head exactly matches the scenario's expected value.
    print(f"\n{'='*60}\n PARALLEL TEST ({a.version})\n{'='*60}\n")
    res = []
    for i,s in enumerate(SC):
        raw,toks,full,ms,hd = run(s,na[i]); mt=max(toks.values()) if toks else 0
        ok = raw.get("function","") == s["expected"]; res.append((s,raw,toks,full,ms,mt,hd,ok))
        print(f"─── {s['name']} ───\n{'PASS' if ok else 'FAIL'} {s['desc']}")
        for nm,_,_ in hd:
            v,tc = raw.get(nm,""),toks.get(nm,0); d=v if len(v)<=43 else v[:43]+"…"
            # function head: OK/WRONG vs expected; arg heads: NULL (unused) or FILL (produced a value)
            st = ("OK" if ok else f"WRONG({v})") if nm=="function" else ("NULL" if v=="<|null|>" else "FILL")
            print(f"   {nm:<10} {d:<45} {tc:<4} {st}")
        print(f"   e2e={ms:.1f}ms max_tok={mt}\n")

    # Summary: aggregate accuracy and latency statistics across scenarios.
    N=len(res); np_=sum(r[7] for r in res); ae=sum(r[4] for r in res)/N; amt=sum(r[5] for r in res)/N
    print(f"{'='*60}\n SUMMARY ({a.version})\n{'='*60}")
    print(f" Accuracy       : {np_}/{N}\n Cold start avg : {sum(cold)/N:.1f}ms\n Hot prefill avg: {sum(sum(h) for h in hot)/sum(len(h) for h in hot):.1f}ms")
    print(f" E2E avg (hot)  : {ae:.1f}ms\n Max head tokens: {amt:.1f} avg\n E2E / max_tok  : {ae/amt:.1f}ms/tok (decode bottleneck)\n")
    print(f" {'Scenario':<35} {'Cold':>7} {'Hot':>7} {'E2E':>7} {'MaxTk':>6} {'ms/tk':>6}\n {'─'*70}")
    for i,(s,_,_,_,ms,mt,_,_) in enumerate(res):
        print(f" {s['name']:<35} {cold[i]:6.1f} {sum(hot[i])/3:6.1f} {ms:6.1f} {mt:>5} {ms/mt if mt else 0:5.1f}")

    # Example dump: first scenario's shared prefix, per-head trigger tokens,
    # raw decode output, and the reconstructed multi-head response.
    s,raw,toks,full,ms,mt,hd,ok = res[0]; base=build_prompt(s,a.version)
    print(f"\n{'='*60}\n EXAMPLE ({a.version}): {s['name']}\n{'='*60}")
    print(f"\n┌─ Shared Prefix ({len(base)} chars) ────────────────────")
    for ln in base.split("\n"): print(f"│ {ln}")
    print(f"└──────────────────────────────────────────────────")
    print(f"\n┌─ Per-Head Trigger Tokens ─────────────────────────")
    for nm,op,_ in hd: print(f"│ {nm:<10} → {op}")
    print(f"└──────────────────────────────────────────────────")
    print(f"\n┌─ Decode Output (all tokens, incl. stop) ──────────")
    for nm,op,_ in hd: print(f"│ {nm:<10} [{toks.get(nm,0):>2} tok] {op}{full.get(nm,'')}")
    print(f"└──────────────────────────────────────────────────")
    print(f"\n Reconstructed multi-head response:")
    for nm,op,cl in hd:
        if raw.get(nm,"")=="<|null|>": print(f"   {op}<|null|>")
        else:
            # Append the close tag only if the model did not already emit a stop string.
            ft=full.get(nm,""); print(f"   {op}{ft}" if any(ft.rstrip().endswith(x) for x in STOPS) else f"   {op}{ft}{cl}")
    print()
128
+
129
if __name__ == "__main__":
    main()