Cialtion commited on
Commit
707405f
·
verified ·
1 Parent(s): 6fd5da2

Upload vllm_rt_qwen_mobile_actions.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vllm_rt_qwen_mobile_actions.py +423 -0
vllm_rt_qwen_mobile_actions.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ vLLM Prefill/Decode 分离性能测试脚本
3
+ =============================================
4
+ 核心监控指标:
5
+ - Prefill tokens 数量 (实际需要计算的)
6
+ - KV Cache 命中 tokens 数量
7
+ - Prefill 耗时 (ms)
8
+ - Decode 耗时 (ms per token)
9
+ - 每步 overhead 分析
10
+ """
11
+
12
+ import time
13
+ import uuid
14
+ import numpy as np
15
+ from typing import List, Dict, Any, Optional
16
+ from dataclasses import dataclass, field
17
+
18
+ from vllm import LLM, SamplingParams
19
+
20
+
21
+ @dataclass
22
+ class StepMetrics:
23
+ """单步指标"""
24
+ step_idx: int
25
+ step_type: str # "prefill" or "decode"
26
+ duration_ms: float
27
+ tokens_processed: int = 0 # prefill时是处理的token数,decode时是1
28
+
29
+
30
+ @dataclass
31
+ class RequestMetrics:
32
+ """单请求完整指标"""
33
+ request_id: str
34
+ tag: str
35
+
36
+ # Token 统计
37
+ total_prompt_tokens: int = 0
38
+ cached_tokens: int = 0 # KV cache 命中的
39
+ computed_tokens: int = 0 # 实际需要 prefill 的
40
+ output_tokens: int = 0
41
+
42
+ # 时间指标 (ms)
43
+ prefill_ms: float = 0.0
44
+ decode_total_ms: float = 0.0
45
+ decode_per_token_ms: float = 0.0
46
+ total_ms: float = 0.0
47
+
48
+ # 每步详情
49
+ steps: List[StepMetrics] = field(default_factory=list)
50
+
51
+ # 输出
52
+ output_text: str = ""
53
+ stop_reason: str = ""
54
+
55
+
56
+ def extract_cache_metrics(output) -> Dict[str, int]:
57
+ """
58
+ 从 vLLM RequestOutput 提取 cache 相关指标
59
+ 兼容不同 vLLM 版本
60
+ """
61
+ result = {
62
+ 'num_cached_tokens': 0,
63
+ 'num_computed_tokens': 0,
64
+ 'num_prompt_tokens': 0,
65
+ }
66
+
67
+ try:
68
+ # vLLM 0.5+ metrics
69
+ if hasattr(output, 'metrics') and output.metrics:
70
+ m = output.metrics
71
+ result['num_cached_tokens'] = getattr(m, 'num_cached_tokens', 0) or 0
72
+ result['num_computed_tokens'] = getattr(m, 'num_computed_tokens', 0) or 0
73
+ result['num_prompt_tokens'] = getattr(m, 'num_prompt_tokens', 0) or 0
74
+
75
+ # 备用:直接从 prompt_token_ids 获取
76
+ if result['num_prompt_tokens'] == 0 and hasattr(output, 'prompt_token_ids'):
77
+ result['num_prompt_tokens'] = len(output.prompt_token_ids)
78
+
79
+ except Exception:
80
+ pass
81
+
82
+ return result
83
+
84
+
85
+ class PDProfiler:
86
+ """Prefill/Decode 性能分析器"""
87
+
88
+ def __init__(self, llm: LLM):
89
+ self.llm = llm
90
+ self.engine = llm.llm_engine
91
+ self.tokenizer = llm.get_tokenizer()
92
+
93
+ def profile_request(
94
+ self,
95
+ prompt: str,
96
+ tag: str = "default",
97
+ sampling_params: Optional[SamplingParams] = None
98
+ ) -> RequestMetrics:
99
+ """
100
+ 分析单个请求的 P/D 性能
101
+ """
102
+ if sampling_params is None:
103
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
104
+
105
+ metrics = RequestMetrics(
106
+ request_id=str(uuid.uuid4()),
107
+ tag=tag
108
+ )
109
+
110
+ # 计算 prompt tokens
111
+ prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
112
+ metrics.total_prompt_tokens = len(prompt_tokens)
113
+
114
+ # 提交请求
115
+ self.engine.add_request(metrics.request_id, prompt, sampling_params)
116
+
117
+ # Step 循环,记录每步耗时
118
+ step_idx = 0
119
+ start_time = time.perf_counter()
120
+ prev_output_len = 0
121
+
122
+ while self.engine.has_unfinished_requests():
123
+ step_start = time.perf_counter()
124
+ outputs = self.engine.step()
125
+ step_end = time.perf_counter()
126
+ step_ms = (step_end - step_start) * 1000
127
+
128
+ for out in outputs:
129
+ if out.request_id != metrics.request_id:
130
+ continue
131
+
132
+ # 判断是 prefill 还是 decode
133
+ current_output_len = len(out.outputs[0].token_ids) if out.outputs else 0
134
+
135
+ if step_idx == 0:
136
+ # 第一步是 prefill
137
+ step_type = "prefill"
138
+ tokens_in_step = metrics.total_prompt_tokens
139
+ else:
140
+ step_type = "decode"
141
+ tokens_in_step = current_output_len - prev_output_len
142
+
143
+ metrics.steps.append(StepMetrics(
144
+ step_idx=step_idx,
145
+ step_type=step_type,
146
+ duration_ms=step_ms,
147
+ tokens_processed=tokens_in_step
148
+ ))
149
+
150
+ prev_output_len = current_output_len
151
+
152
+ # 请求完成时提取最终指标
153
+ if out.finished:
154
+ cache_info = extract_cache_metrics(out)
155
+ metrics.cached_tokens = cache_info['num_cached_tokens']
156
+ metrics.computed_tokens = cache_info['num_computed_tokens']
157
+
158
+ # 如果 vLLM 没返回 computed_tokens,手动计算
159
+ if metrics.computed_tokens == 0 and metrics.cached_tokens > 0:
160
+ metrics.computed_tokens = metrics.total_prompt_tokens - metrics.cached_tokens
161
+ elif metrics.computed_tokens == 0 and metrics.cached_tokens == 0:
162
+ metrics.computed_tokens = metrics.total_prompt_tokens
163
+
164
+ metrics.output_tokens = current_output_len
165
+ metrics.output_text = out.outputs[0].text if out.outputs else ""
166
+ metrics.stop_reason = str(getattr(out.outputs[0], 'finish_reason', '')) if out.outputs else ""
167
+
168
+ step_idx += 1
169
+
170
+ metrics.total_ms = (time.perf_counter() - start_time) * 1000
171
+
172
+ # 汇总时间指标
173
+ prefill_steps = [s for s in metrics.steps if s.step_type == "prefill"]
174
+ decode_steps = [s for s in metrics.steps if s.step_type == "decode"]
175
+
176
+ metrics.prefill_ms = sum(s.duration_ms for s in prefill_steps)
177
+ metrics.decode_total_ms = sum(s.duration_ms for s in decode_steps)
178
+
179
+ if metrics.output_tokens > 0:
180
+ metrics.decode_per_token_ms = metrics.decode_total_ms / metrics.output_tokens
181
+
182
+ return metrics
183
+
184
+ def warmup(self, prompt: str):
185
+ """预热,确保 KV cache 被填充"""
186
+ self.engine.add_request(
187
+ str(uuid.uuid4()),
188
+ prompt,
189
+ SamplingParams(max_tokens=1)
190
+ )
191
+ while self.engine.has_unfinished_requests():
192
+ self.engine.step()
193
+
194
+
195
+ def print_metrics_table(metrics_list: List[RequestMetrics], title: str = ""):
196
+ """打印性能指标表格"""
197
+
198
+ print(f"\n{'='*120}")
199
+ if title:
200
+ print(f" {title}")
201
+ print(f"{'='*120}")
202
+
203
+ # 表头
204
+ headers = [
205
+ "Tag", "PromptTok", "Cached", "Computed", "OutTok",
206
+ "Prefill(ms)", "Decode(ms)", "Dec/Tok(ms)", "Total(ms)", "Output"
207
+ ]
208
+ widths = [12, 10, 8, 10, 8, 12, 12, 12, 10, 30]
209
+
210
+ header_line = " | ".join(f"{h:<{w}}" for h, w in zip(headers, widths))
211
+ print(header_line)
212
+ print("-" * 120)
213
+
214
+ for m in metrics_list:
215
+ output_preview = m.output_text[:28].replace('\n', '\\n') + "..." if len(m.output_text) > 28 else m.output_text.replace('\n', '\\n')
216
+
217
+ row = [
218
+ m.tag[:12],
219
+ str(m.total_prompt_tokens),
220
+ str(m.cached_tokens),
221
+ str(m.computed_tokens),
222
+ str(m.output_tokens),
223
+ f"{m.prefill_ms:.2f}",
224
+ f"{m.decode_total_ms:.2f}",
225
+ f"{m.decode_per_token_ms:.2f}",
226
+ f"{m.total_ms:.2f}",
227
+ output_preview
228
+ ]
229
+
230
+ row_line = " | ".join(f"{v:<{w}}" for v, w in zip(row, widths))
231
+ print(row_line)
232
+
233
+ print("-" * 120)
234
+
235
+ # 汇总统计
236
+ if len(metrics_list) > 1:
237
+ avg_prefill = np.mean([m.prefill_ms for m in metrics_list])
238
+ avg_decode_per_tok = np.mean([m.decode_per_token_ms for m in metrics_list if m.decode_per_token_ms > 0])
239
+ total_cached = sum(m.cached_tokens for m in metrics_list)
240
+ total_computed = sum(m.computed_tokens for m in metrics_list)
241
+ cache_hit_rate = total_cached / (total_cached + total_computed) * 100 if (total_cached + total_computed) > 0 else 0
242
+
243
+ print(f"[Summary] Avg Prefill: {avg_prefill:.2f}ms | Avg Decode/Tok: {avg_decode_per_tok:.2f}ms | Cache Hit Rate: {cache_hit_rate:.1f}%")
244
+
245
+
246
+ def print_step_details(metrics: RequestMetrics):
247
+ """打印单请求的每步详情"""
248
+ print(f"\n[Step Details for '{metrics.tag}']")
249
+ print(f" {'Step':<6} {'Type':<10} {'Duration(ms)':<14} {'Tokens':<8}")
250
+ print(f" {'-'*40}")
251
+ for s in metrics.steps:
252
+ print(f" {s.step_idx:<6} {s.step_type:<10} {s.duration_ms:<14.2f} {s.tokens_processed:<8}")
253
+
254
+
255
+ # ==============================================================================
256
+ # 主测试
257
+ # ==============================================================================
258
+
259
+ def main():
260
+ # ==================== 配置 ====================
261
+ MODEL_PATH = "./RT-Qwen3-4B-AWQ"
262
+
263
+ print("="*60)
264
+ print(" vLLM Prefill/Decode 分离性能测试")
265
+ print("="*60)
266
+
267
+ # 初始化 LLM
268
+ print("\n[Init] Loading model...")
269
+ llm = LLM(
270
+ model=MODEL_PATH,
271
+ trust_remote_code=True,
272
+ enable_prefix_caching=True,
273
+ tensor_parallel_size=1,
274
+ max_num_seqs=16,
275
+ gpu_memory_utilization=0.8,
276
+ enforce_eager=False,
277
+ block_size=16,
278
+ max_model_len=8192,
279
+ # 关键:增大 chunk size 减少 chunked prefill 开销
280
+ # max_num_batched_tokens=8192, # 可选:设置更大的值
281
+ )
282
+
283
+ profiler = PDProfiler(llm)
284
+ tokenizer = llm.get_tokenizer()
285
+
286
+ # ==================== System Prompt ====================
287
+ system_prompt = (
288
+ "<|im_start|>system\n"
289
+ "You are a multi-head parallel function calling model. \n"
290
+ "## Output Heads\n\n"
291
+ "**Head 0 - <content>**: Natural language response\n"
292
+ "- Format: <content>response text</content>\n"
293
+ "- Answer what you want to say while you are calling a function\n\n"
294
+ "**Head 1 - <function>**: Function names to call\n"
295
+ "- Format: <function>name</function>\n"
296
+ "- Name: must match tool defined name\n\n"
297
+ "**Head 2-7 - <arg1>、<arg2>、<arg3>、<arg4>、<arg5>、<arg6>**: Function arguments by position\n"
298
+ "- Format: <argN>value</argN> \n"
299
+ "- Strictly fill in according to the parameter order of the tool you intend to call\n"
300
+ "- Note the special restrictions of parameter definitions for corresponding positions\n"
301
+ "- If the corresponding tool definition has required parameters, these must be filled in\n"
302
+ "- Infer the user's actual needs.\n"
303
+ "- If Unnecessary: <argN><|null|></argN>\n\n"
304
+ "**Environment - The information you have.\n**History - The tools you have called.\n\n"
305
+ "## Available Tools:\n\n"
306
+ '{"type": "function", "function": {"name": "open_wifi_settings", "description": "Opens the Wi-Fi settings.", "parameters": {"type": "object", "properties": {}}}}\n'
307
+ '{"type": "function", "function": {"name": "create_contact", "description": "Creates a contact in the phone\'s contact list.", "parameters": {"type": "object", "properties": {"first_name": {"type": "string", "description": "The first name of the contact."}, "last_name": {"type": "string", "description": "The last name of the contact."}, "email": {"type": "string", "description": "The email address of the contact.", "optional": true}, "phone_number": {"type": "string", "description": "The phone number of the contact.", "optional": true}}, "required": ["first_name", "last_name"]}}}\n'
308
+ '{"type": "function", "function": {"name": "show_map", "description": "Shows a location on the map.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The location to search for. May be the name of a place, a business, or an address."}}, "required": ["query"]}}}\n'
309
+ '{"type": "function", "function": {"name": "create_calendar_event", "description": "Creates a new calendar event.", "parameters": {"type": "object", "properties": {"title": {"type": "string", "description": "The title of the event."}, "datetime": {"type": "string", "description": "The date and time of the event in the format YYYY-MM-DDTHH:MM:SS."}}, "required": ["title", "datetime"]}}}\n'
310
+ '{"type": "function", "function": {"name": "send_email", "description": "Sends an email.", "parameters": {"type": "object", "properties": {"to": {"type": "string", "description": "The email address of the recipient."}, "subject": {"type": "string", "description": "The subject of the email."}, "body": {"type": "string", "description": "The body of the email.", "optional": true}}, "required": ["to", "subject"]}}}\n'
311
+ '{"type": "function", "function": {"name": "turn_off_flashlight", "description": "Turns the flashlight off.", "parameters": {"type": "object", "properties": {}}}}\n'
312
+ '{"type": "function", "function": {"name": "turn_on_flashlight", "description": "Turns the flashlight on.", "parameters": {"type": "object", "properties": {}}}}\n'
313
+ "<|im_end|>\n"
314
+ )
315
+
316
+ system_tokens = len(tokenizer.encode(system_prompt, add_special_tokens=False))
317
+ print(f"[System Prompt] {system_tokens} tokens")
318
+
319
+ # ==================== 测试用例 ====================
320
+ test_queries = [
321
+ ("create_contact", "<|im_start|>user\nenvironment: [\"No develop information provided\"]\nhistory: []\n\nCan you please save a new contact for me? The name is Lena Petrova, the phone number is +359 888 123 456, and the email is lena.petrova.design@webmail.com.<|im_end|>\n<|im_start|>assistant\n"),
322
+ ("send_email", "<|im_start|>user\nenvironment: [\"No develop information provided\"]\nhistory: []\n\nPlease send an email to javier.ortega@ecotradeintl.com with the subject 'Update on Q4 Report' and the body 'I've uploaded the revised figures to the shared drive.'<|im_end|>\n<|im_start|>assistant\n"),
323
+ ("calendar", "<|im_start|>user\nenvironment: [\"No develop information provided\"]\nhistory: []\n\nPlease set up a new calendar event for 'Team Lunch with Marketing' on May 13, 2025 at 1:30 PM.<|im_end|>\n<|im_start|>assistant\n"),
324
+ ("flashlight_map", "<|im_start|>user\nenvironment: [\"No develop information provided\"]\nhistory: []\n\nTurn on the flashlight and show me the location of the Sunnyvale Library on the map.<|im_end|>\n<|im_start|>assistant\n"),
325
+ ("flashlight_map_w_history", "<|im_start|>user\nenvironment: [\"No develop information provided\"]\nhistory: [turn_on_flashlight()]\n\nTurn on the flashlight and show me the location of the Sunnyvale Library on the map.<|im_end|>\n<|im_start|>assistant\n"),
326
+ ]
327
+
328
+ head_tags = ["<function>", "<arg1>", "<arg2>", "<arg3>", "<arg4>", "<arg5>"]
329
+
330
+ stop_tokens = [
331
+ "<|null|>", "</content>", "</function>",
332
+ "</arg1>", "</arg2>", "</arg3>", "</arg4>", "</arg5>", "</arg6>"
333
+ ]
334
+
335
+ sampling_params = SamplingParams(
336
+ temperature=0.0,
337
+ max_tokens=16,
338
+ stop=stop_tokens,
339
+ include_stop_str_in_output=True
340
+ )
341
+
342
+ # ==================== 测试循环 ====================
343
+ for query_idx, (query_name, query) in enumerate(test_queries):
344
+ print(f"\n{'#'*80}")
345
+ print(f"# ROUND {query_idx + 1}: {query_name}")
346
+ print(f"{'#'*80}")
347
+
348
+ full_prefix = system_prompt + query
349
+ prefix_tokens = len(tokenizer.encode(full_prefix, add_special_tokens=False))
350
+ query_tokens = prefix_tokens - system_tokens
351
+
352
+ print(f"[Prefix] System: {system_tokens} + Query: {query_tokens} = Total: {prefix_tokens} tokens")
353
+
354
+ # ---------------------------------------------------------
355
+ # 1. 冷启动 Warmup (填充 KV cache)
356
+ # ---------------------------------------------------------
357
+ print(f"\n--- Phase 1: Cold Start Warmup ---")
358
+ warmup_metrics = profiler.profile_request(
359
+ full_prefix,
360
+ tag="warmup",
361
+ sampling_params=SamplingParams(max_tokens=1)
362
+ )
363
+ print(f"[Warmup] Prefill {warmup_metrics.computed_tokens} tokens in {warmup_metrics.prefill_ms:.2f}ms")
364
+ print(f" Tokens/sec: {warmup_metrics.computed_tokens / warmup_metrics.prefill_ms * 1000:.0f}")
365
+
366
+ # ---------------------------------------------------------
367
+ # 2. 热启动测试 (KV cache 应该命中)
368
+ # ---------------------------------------------------------
369
+ print(f"\n--- Phase 2: Hot Start (Cache Hit Expected) ---")
370
+ all_metrics = []
371
+
372
+ for head_tag in head_tags:
373
+ head_prompt = full_prefix + head_tag
374
+ metrics = profiler.profile_request(
375
+ head_prompt,
376
+ tag=head_tag,
377
+ sampling_params=sampling_params
378
+ )
379
+ all_metrics.append(metrics)
380
+
381
+ # 打印表格
382
+ print_metrics_table(all_metrics, f"Round {query_idx + 1}: {query_name}")
383
+
384
+ # 打印第一个 head 的步骤详情
385
+ if all_metrics:
386
+ print_step_details(all_metrics[0])
387
+
388
+ # ==================== 额外测试:不同序列长度的冷启动性能 ====================
389
+ print(f"\n{'='*80}")
390
+ print(" BONUS: Cold Start Prefill Performance vs Sequence Length")
391
+ print(f"{'='*80}")
392
+
393
+ # 构造不同长度的 prompt
394
+ base_prompt = system_prompt
395
+ padding_text = "This is padding text to test prefill performance. " * 10
396
+
397
+ length_tests = []
398
+ for target_len in [256, 512, 1024, 2048]:
399
+ # 构造指定长度的 prompt
400
+ test_prompt = base_prompt
401
+ current_len = len(tokenizer.encode(test_prompt, add_special_tokens=False))
402
+
403
+ while current_len < target_len:
404
+ test_prompt += padding_text
405
+ current_len = len(tokenizer.encode(test_prompt, add_special_tokens=False))
406
+
407
+ # 冷启动测试 (新 prompt,无 cache)
408
+ metrics = profiler.profile_request(
409
+ test_prompt + f"<unique_{uuid.uuid4().hex[:8]}>", # 确保无 cache
410
+ tag=f"len_{target_len}",
411
+ sampling_params=SamplingParams(max_tokens=1)
412
+ )
413
+ length_tests.append(metrics)
414
+
415
+ throughput = metrics.computed_tokens / metrics.prefill_ms * 1000 if metrics.prefill_ms > 0 else 0
416
+ print(f"[Seq {target_len:4d}] Prefill: {metrics.prefill_ms:7.2f}ms | Computed: {metrics.computed_tokens:4d} | Throughput: {throughput:8.0f} tok/s")
417
+
418
+ print("\n[Analysis] If Prefill time grows faster than linearly, chunked prefill overhead is significant.")
419
+ print("[Tip] Try increasing max_num_batched_tokens or disabling chunked prefill for latency-critical workloads.")
420
+
421
+
422
+ if __name__ == "__main__":
423
+ main()