a0y0346 committed on
Commit
509d7b6
·
1 Parent(s): 6492c04

Add separate KV Cache dtype selector (FP16/BF16/FP8/INT8)

Browse files

- KV cache precision can now be configured independently from weight precision
- Realistic representation: model weights and KV cache have different quantization options
- FP8/INT8 KV cache halves memory vs FP16, enabling 2× longer contexts
- Updated insight panel to show KV cache memory savings
- Added more batch sizes (64, 128) to tradeoff chart

Files changed (2) hide show
  1. app.py +9 -3
  2. src/memory_budget.py +45 -20
app.py CHANGED
@@ -597,6 +597,12 @@ def create_app() -> gr.Blocks:
597
  label="Weight Precision",
598
  info="Data type for model weights",
599
  )
 
 
 
 
 
 
600
 
601
  with gr.Row():
602
  budget_batch = gr.Slider(
@@ -629,11 +635,11 @@ def create_app() -> gr.Blocks:
629
 
630
  # Memory analysis callback
631
  @spaces.GPU(duration=60)
632
- def analyze_memory(model_name: str, seq_len: int, batch_size: int, dtype: str):
633
  """Run memory analysis on selected model."""
634
  try:
635
  budget, breakdown_fig, scaling_fig, tradeoff_fig, insight = run_memory_analysis(
636
- model_name, int(seq_len), int(batch_size), dtype
637
  )
638
  status = f"**{model_name}**: {budget['breakdown']['total_gb']:.2f} GB total ({budget['utilization_pct']:.1f}% GPU)"
639
  return status, breakdown_fig, scaling_fig, tradeoff_fig, insight
@@ -643,7 +649,7 @@ def create_app() -> gr.Blocks:
643
 
644
  calculate_btn.click(
645
  fn=analyze_memory,
646
- inputs=[budget_model, budget_seq, budget_batch, budget_dtype],
647
  outputs=[budget_status, breakdown_chart, scaling_chart, tradeoff_chart, budget_insight],
648
  )
649
 
 
597
  label="Weight Precision",
598
  info="Data type for model weights",
599
  )
600
+ budget_kv_dtype = gr.Dropdown(
601
+ choices=["FP16", "BF16", "FP8", "INT8"],
602
+ value="FP16",
603
+ label="KV Cache Precision",
604
+ info="Data type for KV cache (FP8/INT8 need compatible hardware)",
605
+ )
606
 
607
  with gr.Row():
608
  budget_batch = gr.Slider(
 
635
 
636
  # Memory analysis callback
637
  @spaces.GPU(duration=60)
638
+ def analyze_memory(model_name: str, seq_len: int, batch_size: int, dtype: str, kv_dtype: str):
639
  """Run memory analysis on selected model."""
640
  try:
641
  budget, breakdown_fig, scaling_fig, tradeoff_fig, insight = run_memory_analysis(
642
+ model_name, int(seq_len), int(batch_size), dtype, kv_dtype
643
  )
644
  status = f"**{model_name}**: {budget['breakdown']['total_gb']:.2f} GB total ({budget['utilization_pct']:.1f}% GPU)"
645
  return status, breakdown_fig, scaling_fig, tradeoff_fig, insight
 
649
 
650
  calculate_btn.click(
651
  fn=analyze_memory,
652
+ inputs=[budget_model, budget_seq, budget_batch, budget_dtype, budget_kv_dtype],
653
  outputs=[budget_status, breakdown_chart, scaling_chart, tradeoff_chart, budget_insight],
654
  )
655
 
src/memory_budget.py CHANGED
@@ -129,11 +129,21 @@ def calculate_kv_cache_memory(
129
  }
130
 
131
 
 
 
 
 
 
 
 
 
 
132
  def calculate_memory_budget(
133
  model_name: str,
134
  seq_len: int,
135
  batch_size: int = 1,
136
  dtype: str = "FP16",
 
137
  ) -> dict:
138
  """
139
  Calculate complete memory budget using REAL model and GPU info.
@@ -143,6 +153,7 @@ def calculate_memory_budget(
143
  seq_len: Target sequence length
144
  batch_size: Batch size
145
  dtype: Data type for model weights
 
146
 
147
  Returns:
148
  Complete memory budget breakdown
@@ -158,14 +169,15 @@ def calculate_memory_budget(
158
  model_weights_bytes = model_info["num_parameters"] * dtype_bytes
159
  model_weights_gb = model_weights_bytes / (1024 ** 3)
160
 
161
- # KV cache memory
 
162
  kv_cache = calculate_kv_cache_memory(
163
  num_kv_heads=model_info["num_kv_heads"],
164
  head_dim=model_info["head_dim"],
165
  num_layers=model_info["num_layers"],
166
  seq_len=seq_len,
167
  batch_size=batch_size,
168
- dtype_bytes=int(dtype_bytes) if dtype_bytes >= 1 else 2, # KV cache usually FP16
169
  )
170
  kv_cache_gb = kv_cache["gb"]
171
 
@@ -187,6 +199,7 @@ def calculate_memory_budget(
187
  "model_info": model_info,
188
  "gpu_info": gpu_info,
189
  "dtype": dtype,
 
190
  "seq_len": seq_len,
191
  "batch_size": batch_size,
192
  "breakdown": {
@@ -206,6 +219,7 @@ def calculate_max_context_length(
206
  model_name: str,
207
  batch_size: int = 1,
208
  dtype: str = "FP16",
 
209
  memory_reserve_pct: float = 10.0,
210
  ) -> dict:
211
  """
@@ -217,6 +231,7 @@ def calculate_max_context_length(
217
  model_name: Model to analyze
218
  batch_size: Batch size
219
  dtype: Data type for weights
 
220
  memory_reserve_pct: Percentage to reserve for activations/overhead
221
 
222
  Returns:
@@ -251,7 +266,7 @@ def calculate_max_context_length(
251
 
252
  # Calculate max seq_len from available KV cache memory
253
  # KV cache bytes = 2 × num_kv_heads × head_dim × seq_len × num_layers × batch × dtype_bytes
254
- kv_dtype_bytes = 2 # KV cache typically FP16
255
  bytes_per_token = (
256
  2 * model_info["num_kv_heads"] * model_info["head_dim"] *
257
  model_info["num_layers"] * batch_size * kv_dtype_bytes
@@ -369,7 +384,7 @@ def create_memory_breakdown_chart(budget: dict) -> go.Figure:
369
  return fig
370
 
371
 
372
- def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: str = "FP16") -> go.Figure:
373
  """
374
  Create chart showing memory usage vs context length.
375
 
@@ -382,6 +397,9 @@ def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: st
382
  dtype_bytes = {"FP32": 4, "FP16": 2, "BF16": 2, "INT8": 1, "INT4": 0.5}.get(dtype, 2)
383
  model_weights_gb = (model_info["num_parameters"] * dtype_bytes) / (1024 ** 3)
384
 
 
 
 
385
  # Sequence lengths to plot
386
  seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]
387
 
@@ -395,6 +413,7 @@ def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: st
395
  num_layers=model_info["num_layers"],
396
  seq_len=seq_len,
397
  batch_size=batch_size,
 
398
  )
399
  kv_cache_values.append(kv["gb"])
400
  total_memory_values.append(model_weights_gb + kv["gb"])
@@ -439,7 +458,7 @@ def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: st
439
  fig.update_layout(
440
  title=dict(
441
  text=f"Memory Scaling: {model_name}<br>"
442
- f"<sub>batch={batch_size}, {dtype}, {model_info['num_kv_heads']} KV heads</sub>",
443
  x=0.5,
444
  ),
445
  xaxis_title="Context Length (tokens)",
@@ -466,7 +485,7 @@ def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: st
466
  return fig
467
 
468
 
469
- def create_batch_context_tradeoff_chart(model_name: str, dtype: str = "FP16") -> go.Figure:
470
  """
471
  Create chart showing batch size vs max context tradeoff.
472
 
@@ -487,11 +506,11 @@ def create_batch_context_tradeoff_chart(model_name: str, dtype: str = "FP16") ->
487
  )
488
  return fig
489
 
490
- batch_sizes = [1, 2, 4, 8, 16, 32]
491
  max_contexts = []
492
 
493
  for batch in batch_sizes:
494
- result = calculate_max_context_length(model_name, batch_size=batch, dtype=dtype)
495
  max_contexts.append(result.get("max_context", 0))
496
 
497
  fig = go.Figure()
@@ -517,7 +536,7 @@ def create_batch_context_tradeoff_chart(model_name: str, dtype: str = "FP16") ->
517
  fig.update_layout(
518
  title=dict(
519
  text=f"Batch Size vs Max Context: {model_name}<br>"
520
- f"<sub>GPU: {gpu_info['name']} ({gpu_info['total_memory_gb']:.1f} GB)</sub>",
521
  x=0.5,
522
  ),
523
  xaxis_title="Batch Size",
@@ -538,6 +557,7 @@ def run_memory_analysis(
538
  seq_len: int,
539
  batch_size: int = 1,
540
  dtype: str = "FP16",
 
541
  ) -> tuple:
542
  """
543
  Run complete memory analysis for a model.
@@ -545,15 +565,15 @@ def run_memory_analysis(
545
  Returns budget info, charts, and insight text.
546
  """
547
  # Calculate budget
548
- budget = calculate_memory_budget(model_name, seq_len, batch_size, dtype)
549
 
550
  # Calculate max context
551
- max_context = calculate_max_context_length(model_name, batch_size, dtype)
552
 
553
  # Create charts
554
  breakdown_chart = create_memory_breakdown_chart(budget)
555
- scaling_chart = create_context_scaling_chart(model_name, batch_size, dtype)
556
- tradeoff_chart = create_batch_context_tradeoff_chart(model_name, dtype)
557
 
558
  # Generate insight text
559
  model_info = budget["model_info"]
@@ -562,6 +582,10 @@ def run_memory_analysis(
562
 
563
  status_emoji = "✅" if budget["fits_in_gpu"] else "❌"
564
 
 
 
 
 
565
  insight = f"""### {model_name} Memory Analysis
566
 
567
  **Model Configuration (from model.config):**
@@ -578,12 +602,12 @@ def run_memory_analysis(
578
 
579
  ### Memory Breakdown at {seq_len:,} tokens (batch={batch_size})
580
 
581
- | Component | Memory |
582
- |-----------|--------|
583
- | Model Weights ({dtype}) | {breakdown['model_weights_gb']:.2f} GB |
584
- | KV Cache | {breakdown['kv_cache_gb']:.2f} GB |
585
- | Activations (est.) | {breakdown['activation_gb']:.2f} GB |
586
- | **Total** | **{breakdown['total_gb']:.2f} GB** |
587
 
588
  **GPU Utilization:** {budget['utilization_pct']:.1f}%
589
  **Headroom:** {budget['headroom_gb']:.2f} GB {status_emoji}
@@ -592,7 +616,7 @@ def run_memory_analysis(
592
 
593
  ### Maximum Context Length
594
 
595
- At batch size {batch_size} with {dtype}:
596
  - **Max context:** {max_context.get('max_context', 0):,} tokens
597
  - Available for KV cache: {max_context.get('available_for_kv_gb', 0):.2f} GB
598
  - KV cache per token: {max_context.get('bytes_per_token', 0):,} bytes
@@ -605,6 +629,7 @@ At batch size {batch_size} with {dtype}:
605
  - **GQA reduces** KV cache by {model_info['gqa_ratio']}× vs MHA
606
  - **Batch size trades off** with maximum context length
607
  - **{dtype} weights** use {breakdown['model_weights_gb']:.2f} GB
 
608
  """
609
 
610
  return budget, breakdown_chart, scaling_chart, tradeoff_chart, insight
 
129
  }
130
 
131
 
132
+ # KV Cache dtype bytes mapping
133
+ KV_DTYPE_BYTES = {
134
+ "FP16": 2,
135
+ "BF16": 2,
136
+ "FP8": 1,
137
+ "INT8": 1,
138
+ }
139
+
140
+
141
  def calculate_memory_budget(
142
  model_name: str,
143
  seq_len: int,
144
  batch_size: int = 1,
145
  dtype: str = "FP16",
146
+ kv_dtype: str = "FP16",
147
  ) -> dict:
148
  """
149
  Calculate complete memory budget using REAL model and GPU info.
 
153
  seq_len: Target sequence length
154
  batch_size: Batch size
155
  dtype: Data type for model weights
156
+ kv_dtype: Data type for KV cache (FP16, BF16, FP8, INT8)
157
 
158
  Returns:
159
  Complete memory budget breakdown
 
169
  model_weights_bytes = model_info["num_parameters"] * dtype_bytes
170
  model_weights_gb = model_weights_bytes / (1024 ** 3)
171
 
172
+ # KV cache memory - uses separate kv_dtype
173
+ kv_dtype_bytes = KV_DTYPE_BYTES.get(kv_dtype, 2)
174
  kv_cache = calculate_kv_cache_memory(
175
  num_kv_heads=model_info["num_kv_heads"],
176
  head_dim=model_info["head_dim"],
177
  num_layers=model_info["num_layers"],
178
  seq_len=seq_len,
179
  batch_size=batch_size,
180
+ dtype_bytes=kv_dtype_bytes,
181
  )
182
  kv_cache_gb = kv_cache["gb"]
183
 
 
199
  "model_info": model_info,
200
  "gpu_info": gpu_info,
201
  "dtype": dtype,
202
+ "kv_dtype": kv_dtype,
203
  "seq_len": seq_len,
204
  "batch_size": batch_size,
205
  "breakdown": {
 
219
  model_name: str,
220
  batch_size: int = 1,
221
  dtype: str = "FP16",
222
+ kv_dtype: str = "FP16",
223
  memory_reserve_pct: float = 10.0,
224
  ) -> dict:
225
  """
 
231
  model_name: Model to analyze
232
  batch_size: Batch size
233
  dtype: Data type for weights
234
+ kv_dtype: Data type for KV cache (FP16, BF16, FP8, INT8)
235
  memory_reserve_pct: Percentage to reserve for activations/overhead
236
 
237
  Returns:
 
266
 
267
  # Calculate max seq_len from available KV cache memory
268
  # KV cache bytes = 2 × num_kv_heads × head_dim × seq_len × num_layers × batch × dtype_bytes
269
+ kv_dtype_bytes = KV_DTYPE_BYTES.get(kv_dtype, 2)
270
  bytes_per_token = (
271
  2 * model_info["num_kv_heads"] * model_info["head_dim"] *
272
  model_info["num_layers"] * batch_size * kv_dtype_bytes
 
384
  return fig
385
 
386
 
387
+ def create_context_scaling_chart(model_name: str, batch_size: int = 1, dtype: str = "FP16", kv_dtype: str = "FP16") -> go.Figure:
388
  """
389
  Create chart showing memory usage vs context length.
390
 
 
397
  dtype_bytes = {"FP32": 4, "FP16": 2, "BF16": 2, "INT8": 1, "INT4": 0.5}.get(dtype, 2)
398
  model_weights_gb = (model_info["num_parameters"] * dtype_bytes) / (1024 ** 3)
399
 
400
+ # KV cache dtype bytes
401
+ kv_dtype_bytes = KV_DTYPE_BYTES.get(kv_dtype, 2)
402
+
403
  # Sequence lengths to plot
404
  seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]
405
 
 
413
  num_layers=model_info["num_layers"],
414
  seq_len=seq_len,
415
  batch_size=batch_size,
416
+ dtype_bytes=kv_dtype_bytes,
417
  )
418
  kv_cache_values.append(kv["gb"])
419
  total_memory_values.append(model_weights_gb + kv["gb"])
 
458
  fig.update_layout(
459
  title=dict(
460
  text=f"Memory Scaling: {model_name}<br>"
461
+ f"<sub>batch={batch_size}, Weights: {dtype}, KV Cache: {kv_dtype}</sub>",
462
  x=0.5,
463
  ),
464
  xaxis_title="Context Length (tokens)",
 
485
  return fig
486
 
487
 
488
+ def create_batch_context_tradeoff_chart(model_name: str, dtype: str = "FP16", kv_dtype: str = "FP16") -> go.Figure:
489
  """
490
  Create chart showing batch size vs max context tradeoff.
491
 
 
506
  )
507
  return fig
508
 
509
+ batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
510
  max_contexts = []
511
 
512
  for batch in batch_sizes:
513
+ result = calculate_max_context_length(model_name, batch_size=batch, dtype=dtype, kv_dtype=kv_dtype)
514
  max_contexts.append(result.get("max_context", 0))
515
 
516
  fig = go.Figure()
 
536
  fig.update_layout(
537
  title=dict(
538
  text=f"Batch Size vs Max Context: {model_name}<br>"
539
+ f"<sub>GPU: {gpu_info['name']}, Weights: {dtype}, KV: {kv_dtype}</sub>",
540
  x=0.5,
541
  ),
542
  xaxis_title="Batch Size",
 
557
  seq_len: int,
558
  batch_size: int = 1,
559
  dtype: str = "FP16",
560
+ kv_dtype: str = "FP16",
561
  ) -> tuple:
562
  """
563
  Run complete memory analysis for a model.
 
565
  Returns budget info, charts, and insight text.
566
  """
567
  # Calculate budget
568
+ budget = calculate_memory_budget(model_name, seq_len, batch_size, dtype, kv_dtype)
569
 
570
  # Calculate max context
571
+ max_context = calculate_max_context_length(model_name, batch_size, dtype, kv_dtype)
572
 
573
  # Create charts
574
  breakdown_chart = create_memory_breakdown_chart(budget)
575
+ scaling_chart = create_context_scaling_chart(model_name, batch_size, dtype, kv_dtype)
576
+ tradeoff_chart = create_batch_context_tradeoff_chart(model_name, dtype, kv_dtype)
577
 
578
  # Generate insight text
579
  model_info = budget["model_info"]
 
582
 
583
  status_emoji = "✅" if budget["fits_in_gpu"] else "❌"
584
 
585
+ # Calculate KV cache memory savings from quantization
586
+ kv_bytes = KV_DTYPE_BYTES.get(kv_dtype, 2)
587
+ kv_savings = (2 - kv_bytes) / 2 * 100 if kv_bytes < 2 else 0
588
+
589
  insight = f"""### {model_name} Memory Analysis
590
 
591
  **Model Configuration (from model.config):**
 
602
 
603
  ### Memory Breakdown at {seq_len:,} tokens (batch={batch_size})
604
 
605
+ | Component | Memory | Precision |
606
+ |-----------|--------|-----------|
607
+ | Model Weights | {breakdown['model_weights_gb']:.2f} GB | {dtype} |
608
+ | KV Cache | {breakdown['kv_cache_gb']:.2f} GB | {kv_dtype} |
609
+ | Activations (est.) | {breakdown['activation_gb']:.2f} GB | - |
610
+ | **Total** | **{breakdown['total_gb']:.2f} GB** | |
611
 
612
  **GPU Utilization:** {budget['utilization_pct']:.1f}%
613
  **Headroom:** {budget['headroom_gb']:.2f} GB {status_emoji}
 
616
 
617
  ### Maximum Context Length
618
 
619
+ At batch size {batch_size} with {dtype} weights, {kv_dtype} KV cache:
620
  - **Max context:** {max_context.get('max_context', 0):,} tokens
621
  - Available for KV cache: {max_context.get('available_for_kv_gb', 0):.2f} GB
622
  - KV cache per token: {max_context.get('bytes_per_token', 0):,} bytes
 
629
  - **GQA reduces** KV cache by {model_info['gqa_ratio']}× vs MHA
630
  - **Batch size trades off** with maximum context length
631
  - **{dtype} weights** use {breakdown['model_weights_gb']:.2f} GB
632
+ - **{kv_dtype} KV cache**{f" saves {kv_savings:.0f}% vs FP16" if kv_savings > 0 else " (baseline precision)"}
633
  """
634
 
635
  return budget, breakdown_chart, scaling_chart, tradeoff_chart, insight