a0y0346 commited on
Commit
af9b854
·
1 Parent(s): c30936f

Refactor benchmarks to use real model.config values

Browse files

- Add get_real_model_config() to extract config from model.config
- Refactor run_prefill_benchmark to use F.scaled_dot_product_attention
- Refactor run_decode_benchmark with proper KV cache and GQA support
- Update create_kv_cache_chart to use model.config (no constants)
- All config values now come from actual loaded models

Files changed (1) hide show
  1. src/prefill_decode.py +334 -88
src/prefill_decode.py CHANGED
@@ -24,6 +24,38 @@ from .attention_utils import (
24
  )
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def run_prefill_with_real_model(
28
  model,
29
  attention_layer,
@@ -86,6 +118,110 @@ def run_prefill_with_real_model(
86
  return result
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def run_decode_with_real_model(
90
  model,
91
  attention_layer,
@@ -193,6 +329,128 @@ def run_decode_with_real_model(
193
  }
194
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # Legacy function kept for backwards compatibility
197
  def simulate_prefill_attention(
198
  batch_size: int,
@@ -347,95 +605,59 @@ def run_prefill_decode_comparison(
347
  """
348
  Run full comparison between prefill and decode phases using REAL HuggingFace model.
349
 
350
- Loads the actual model, extracts the attention layer, and benchmarks
351
- real attention operations for both prefill and decode phases.
352
 
353
  Returns results dict, comparison chart, KV cache chart, and insight text.
354
  """
355
  if model_name not in MODEL_CONFIGS:
356
  return {"error": f"Unknown model: {model_name}"}, None, None, "Error: Unknown model"
357
 
358
- config = MODEL_CONFIGS[model_name]
 
 
 
 
359
 
360
  results = {
361
  "model": model_name,
362
  "context_length": context_length,
363
  "decode_tokens": decode_tokens,
364
- "config": config,
365
- "using_real_model": True,
366
  }
367
 
368
- try:
369
- # Load the REAL HuggingFace model
370
- model = load_model(model_name)
371
-
372
- # Extract attention layer from layer 0
373
- attention_layer = extract_attention_layer(model, layer_idx=0)
374
-
375
- # Get model attention info
376
- attn_info = get_model_attention_info(model)
377
- results["model_info"] = attn_info
378
-
379
- # Run prefill benchmarks with REAL model attention
380
- prefill_flash = run_prefill_with_real_model(
381
- model=model,
382
- attention_layer=attention_layer,
383
- seq_len=context_length,
384
- batch_size=1,
385
- use_flash=True,
386
- )
387
-
388
- prefill_math = run_prefill_with_real_model(
389
- model=model,
390
- attention_layer=attention_layer,
391
- seq_len=context_length,
392
- batch_size=1,
393
- use_flash=False,
394
- )
395
-
396
- # Run decode benchmarks with REAL model attention
397
- decode_flash = run_decode_with_real_model(
398
- model=model,
399
- attention_layer=attention_layer,
400
- kv_cache_len=context_length,
401
- num_tokens=decode_tokens,
402
- batch_size=1,
403
- use_flash=True,
404
- )
405
-
406
- decode_math = run_decode_with_real_model(
407
- model=model,
408
- attention_layer=attention_layer,
409
- kv_cache_len=context_length,
410
- num_tokens=decode_tokens,
411
- batch_size=1,
412
- use_flash=False,
413
- )
414
-
415
- except Exception as e:
416
- # Fallback to legacy mode if model loading fails
417
- results["using_real_model"] = False
418
- results["fallback_reason"] = str(e)[:100]
419
-
420
- num_heads = config["q_heads"]
421
- head_dim = config["head_dim"]
422
-
423
- prefill_flash = simulate_prefill_attention(
424
- batch_size=1, num_heads=num_heads, seq_len=context_length,
425
- head_dim=head_dim, use_flash=True,
426
- )
427
- prefill_math = simulate_prefill_attention(
428
- batch_size=1, num_heads=num_heads, seq_len=context_length,
429
- head_dim=head_dim, use_flash=False,
430
- )
431
- decode_flash = simulate_decode_attention(
432
- batch_size=1, num_heads=num_heads, kv_cache_len=context_length,
433
- head_dim=head_dim, num_tokens=decode_tokens, use_flash=True,
434
- )
435
- decode_math = simulate_decode_attention(
436
- batch_size=1, num_heads=num_heads, kv_cache_len=context_length,
437
- head_dim=head_dim, num_tokens=decode_tokens, use_flash=False,
438
- )
439
 
440
  results["prefill"] = {
441
  "flash": prefill_flash,
@@ -446,18 +668,27 @@ def run_prefill_decode_comparison(
446
  "math": decode_math,
447
  }
448
 
 
 
 
 
 
 
 
 
 
449
  # Create comparison chart
450
  comparison_chart = create_comparison_chart(results)
451
 
452
- # Create KV cache growth chart
453
- kv_cache_chart = create_kv_cache_chart(config, context_length, decode_tokens)
454
 
455
  # Generate insight
456
  insight = generate_phase_insight(results)
457
 
458
  # Add real model indicator to insight
459
- if results.get("using_real_model"):
460
- model_indicator = f"\n\n---\n\n*Benchmarked using real **{model_name}** attention layer ({attn_info['num_attention_heads']} heads, {attn_info['head_dim']}d)*"
461
  insight = insight + model_indicator
462
 
463
  return results, comparison_chart, kv_cache_chart, insight
@@ -558,16 +789,30 @@ def create_comparison_chart(results: dict) -> go.Figure:
558
  return fig
559
 
560
 
561
- def create_kv_cache_chart(config: dict, context_length: int, decode_tokens: int) -> go.Figure:
562
- """Create chart showing KV cache growth during generation."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
- kv_heads = config["kv_heads"]
565
- head_dim = config["head_dim"]
566
- num_layers = config["layers"]
567
 
568
  # Calculate KV cache size at each step
569
- # KV cache per layer: 2 (K+V) × kv_heads × seq_len × head_dim × 2 (FP16 bytes)
570
- bytes_per_token_per_layer = 2 * kv_heads * head_dim * 2
571
  total_bytes_per_token = bytes_per_token_per_layer * num_layers
572
 
573
  # Generate sequence of token counts
@@ -620,7 +865,7 @@ def create_kv_cache_chart(config: dict, context_length: int, decode_tokens: int)
620
 
621
  fig.update_layout(
622
  title=dict(
623
- text=f"KV Cache Growth ({config.get('kv_heads', 'N/A')} KV heads × {num_layers} layers)",
624
  x=0.5,
625
  ),
626
  xaxis_title="Tokens Processed",
@@ -634,6 +879,7 @@ def create_kv_cache_chart(config: dict, context_length: int, decode_tokens: int)
634
  xanchor="center",
635
  x=0.5,
636
  ),
 
637
  )
638
 
639
  return fig
 
24
  )
25
 
26
 
27
def get_real_model_config(model_name: str) -> dict:
    """
    Load model and extract ACTUAL config values from model.config.

    This function ensures we use real model architecture values,
    NOT hardcoded constants from MODEL_CONFIGS.

    Results are memoized per model name: the benchmark helpers and the
    KV-cache chart each call this, and reloading the full model every time
    just to read its config is very expensive.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")

    Returns:
        Dict with real model configuration values. A fresh copy is returned
        on every call so callers may mutate it without corrupting the cache.
    """
    # Lazily-created memo table stored on the function object itself,
    # so this block needs no module-level state.
    cache = getattr(get_real_model_config, "_cache", None)
    if cache is None:
        cache = {}
        get_real_model_config._cache = cache

    if model_name not in cache:
        model = load_model(model_name)
        config = model.config

        # Extract values directly from model.config
        num_heads = config.num_attention_heads
        # Models without GQA don't define num_key_value_heads; fall back to MHA.
        num_kv_heads = getattr(config, 'num_key_value_heads', num_heads)
        head_dim = config.hidden_size // num_heads

        cache[model_name] = {
            "num_layers": config.num_hidden_layers,
            "num_heads": num_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
            "hidden_size": config.hidden_size,
            "model_type": getattr(config, 'model_type', 'unknown'),
            "gqa_ratio": num_heads // num_kv_heads if num_kv_heads > 0 else 1,
        }

    return dict(cache[model_name])
57
+
58
+
59
  def run_prefill_with_real_model(
60
  model,
61
  attention_layer,
 
118
  return result
119
 
120
 
121
def run_prefill_benchmark(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark the prefill phase with F.scaled_dot_product_attention, sized
    by the model's REAL configuration.

    The model named by ``model_name`` is loaded only to read its true head
    count and head dimension (via get_real_model_config); random Q/K/V
    tensors of that shape are then pushed through SDPA directly, which is
    more reliable than invoking an attention layer's forward().

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        seq_len: Sequence length (prompt tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms, memory_mb, and status; an error dict when CUDA
        is missing or any step fails.
    """
    # Guard clause: this benchmark is meaningless without a GPU.
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        # Pull the real architecture numbers from the loaded model's config.
        real_config = get_real_model_config(model_name)
        num_heads = real_config["num_heads"]
        head_dim = real_config["head_dim"]

        # Random activations shaped [batch, num_heads, seq_len, head_dim].
        shape = (batch_size, num_heads, seq_len, head_dim)
        query, key, value = (
            torch.randn(shape, dtype=dtype, device=device) for _ in range(3)
        )

        # Exactly one backend is enabled per run: flash XOR math.
        kernel_flags = dict(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        )

        def _sdpa_once():
            # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in
            # newer PyTorch in favor of torch.nn.attention.sdpa_kernel.
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                return F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # Warmup runs so kernel selection/compilation isn't timed.
        for _ in range(3):
            _sdpa_once()

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Time the hot loop with CUDA events (GPU-side timestamps).
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        start_evt.record()
        for _ in range(num_iterations):
            output = _sdpa_once()
        end_evt.record()

        torch.cuda.synchronize()

        per_iter_ms = start_evt.elapsed_time(end_evt) / num_iterations
        peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        # Free the large tensors before returning.
        del query, key, value, output
        torch.cuda.empty_cache()

        return {
            "time_ms": round(per_iter_ms, 3),
            "memory_mb": round(peak_mb, 1),
            "seq_len": seq_len,
            "phase": "prefill",
            "backend": "flash" if use_flash else "math",
            "num_heads": num_heads,
            "head_dim": head_dim,
            "status": "success",
            "using_real_config": True,
        }

    except Exception as e:
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
            "phase": "prefill",
        }
223
+
224
+
225
  def run_decode_with_real_model(
226
  model,
227
  attention_layer,
 
329
  }
330
 
331
 
332
def run_decode_benchmark(
    model_name: str,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark the decode phase with F.scaled_dot_product_attention, sized
    by the model's REAL configuration.

    Decode is simulated faithfully:
      * a single query token (Q with seq_len=1) attends over
      * a KV cache of ``kv_cache_len`` tokens, and
      * GQA models have their KV heads expanded to match the Q head count.

    The cache size stays fixed across steps — this measures attention cost
    at a given context length, not cache growth.

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        kv_cache_len: Length of KV cache (context length)
        num_tokens: Number of decode tokens to simulate
        batch_size: Batch size
        num_iterations: Iterations for timing
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms_per_token, memory_mb, and status; an error dict
        when CUDA is missing or any step fails.
    """
    # Guard clause: nothing to measure without a GPU.
    if not torch.cuda.is_available():
        return {"time_ms_per_token": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        # Real architecture numbers from the loaded model's config.
        cfg = get_real_model_config(model_name)
        n_q_heads = cfg["num_heads"]
        n_kv_heads = cfg["num_kv_heads"]
        dim = cfg["head_dim"]

        def _rand(heads, length):
            return torch.randn(batch_size, heads, length, dim, dtype=dtype, device=device)

        # One new query token: [batch, num_heads, 1, head_dim].
        Q = _rand(n_q_heads, 1)
        # Cached keys/values: [batch, num_kv_heads, kv_cache_len, head_dim].
        K_cache = _rand(n_kv_heads, kv_cache_len)
        V_cache = _rand(n_kv_heads, kv_cache_len)

        # GQA: replicate each KV head so K/V match Q's head count.
        if n_kv_heads < n_q_heads:
            group = n_q_heads // n_kv_heads
            K_cache = K_cache.repeat_interleave(group, dim=1)
            V_cache = V_cache.repeat_interleave(group, dim=1)

        # Exactly one SDPA backend enabled per run: flash XOR math.
        kernel_flags = dict(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        )

        def _attend():
            # No causal mask needed: a lone query may see the whole cache.
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                return F.scaled_dot_product_attention(Q, K_cache, V_cache)

        # Warmup so kernel selection isn't timed.
        for _ in range(3):
            _attend()

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        # Each step mimics generating one token against the fixed cache.
        total_steps = num_tokens * num_iterations
        start_evt.record()
        for _ in range(total_steps):
            output = _attend()
        end_evt.record()

        torch.cuda.synchronize()

        total_time_ms = start_evt.elapsed_time(end_evt)
        time_per_token_ms = total_time_ms / total_steps
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        # Free the large tensors before returning.
        del Q, K_cache, V_cache, output
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(time_per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "backend": "flash" if use_flash else "math",
            "num_heads": n_q_heads,
            "num_kv_heads": n_kv_heads,
            "head_dim": dim,
            "status": "success",
            "using_real_config": True,
        }

    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:100]}",
        }
+ }
452
+
453
+
454
  # Legacy function kept for backwards compatibility
455
  def simulate_prefill_attention(
456
  batch_size: int,
 
605
  """
606
  Run full comparison between prefill and decode phases using REAL HuggingFace model.
607
 
608
+ Uses F.scaled_dot_product_attention with real model dimensions for reliable benchmarking.
609
+ All config values come from model.config, not constants.
610
 
611
  Returns results dict, comparison chart, KV cache chart, and insight text.
612
  """
613
  if model_name not in MODEL_CONFIGS:
614
  return {"error": f"Unknown model: {model_name}"}, None, None, "Error: Unknown model"
615
 
616
+ # Get REAL config from model.config (not constants)
617
+ try:
618
+ real_config = get_real_model_config(model_name)
619
+ except Exception as e:
620
+ return {"error": f"Failed to load model: {str(e)[:50]}"}, None, None, f"Error: {str(e)[:50]}"
621
 
622
  results = {
623
  "model": model_name,
624
  "context_length": context_length,
625
  "decode_tokens": decode_tokens,
626
+ "real_config": real_config,
627
+ "using_real_config": True,
628
  }
629
 
630
+ # Run prefill benchmarks using SDPA with REAL model dimensions
631
+ prefill_flash = run_prefill_benchmark(
632
+ model_name=model_name,
633
+ seq_len=context_length,
634
+ batch_size=1,
635
+ use_flash=True,
636
+ )
637
+
638
+ prefill_math = run_prefill_benchmark(
639
+ model_name=model_name,
640
+ seq_len=context_length,
641
+ batch_size=1,
642
+ use_flash=False,
643
+ )
644
+
645
+ # Run decode benchmarks using SDPA with proper KV cache simulation
646
+ decode_flash = run_decode_benchmark(
647
+ model_name=model_name,
648
+ kv_cache_len=context_length,
649
+ num_tokens=decode_tokens,
650
+ batch_size=1,
651
+ use_flash=True,
652
+ )
653
+
654
+ decode_math = run_decode_benchmark(
655
+ model_name=model_name,
656
+ kv_cache_len=context_length,
657
+ num_tokens=decode_tokens,
658
+ batch_size=1,
659
+ use_flash=False,
660
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
 
662
  results["prefill"] = {
663
  "flash": prefill_flash,
 
668
  "math": decode_math,
669
  }
670
 
671
+ # Add model info for display
672
+ results["model_info"] = {
673
+ "num_heads": real_config["num_heads"],
674
+ "num_kv_heads": real_config["num_kv_heads"],
675
+ "head_dim": real_config["head_dim"],
676
+ "num_layers": real_config["num_layers"],
677
+ "gqa_ratio": real_config["gqa_ratio"],
678
+ }
679
+
680
  # Create comparison chart
681
  comparison_chart = create_comparison_chart(results)
682
 
683
+ # Create KV cache growth chart using REAL model config
684
+ kv_cache_chart = create_kv_cache_chart(model_name, context_length, decode_tokens)
685
 
686
  # Generate insight
687
  insight = generate_phase_insight(results)
688
 
689
  # Add real model indicator to insight
690
+ if results.get("using_real_config"):
691
+ model_indicator = f"\n\n---\n\n*Benchmarked using real **{model_name}** config ({real_config['num_heads']} heads, {real_config['head_dim']}d, GQA {real_config['gqa_ratio']}:1)*"
692
  insight = insight + model_indicator
693
 
694
  return results, comparison_chart, kv_cache_chart, insight
 
789
  return fig
790
 
791
 
792
+ def create_kv_cache_chart(model_name: str, context_length: int, decode_tokens: int) -> go.Figure:
793
+ """
794
+ Create chart showing KV cache growth during generation.
795
+
796
+ Uses REAL model config values from model.config, not constants.
797
+
798
+ Args:
799
+ model_name: Model name to load config from
800
+ context_length: Number of context tokens (prefill)
801
+ decode_tokens: Number of decode tokens to generate
802
+
803
+ Returns:
804
+ Plotly figure showing KV cache growth
805
+ """
806
+ # Get REAL config from loaded model (no constants!)
807
+ real_config = get_real_model_config(model_name)
808
 
809
+ num_kv_heads = real_config["num_kv_heads"]
810
+ head_dim = real_config["head_dim"]
811
+ num_layers = real_config["num_layers"]
812
 
813
  # Calculate KV cache size at each step
814
+ # KV cache per layer: 2 (K+V) × kv_heads × head_dim × 2 (FP16 bytes)
815
+ bytes_per_token_per_layer = 2 * num_kv_heads * head_dim * 2
816
  total_bytes_per_token = bytes_per_token_per_layer * num_layers
817
 
818
  # Generate sequence of token counts
 
865
 
866
  fig.update_layout(
867
  title=dict(
868
+ text=f"KV Cache Growth ({num_kv_heads} KV heads × {num_layers} layers)",
869
  x=0.5,
870
  ),
871
  xaxis_title="Tokens Processed",
 
879
  xanchor="center",
880
  x=0.5,
881
  ),
882
+ yaxis=dict(rangemode='tozero'),
883
  )
884
 
885
  return fig