Livengood Claude commited on
Commit
548f0fb
·
1 Parent(s): 26bc78c

Fix UI bugs: tabs, accordion, examples, and sliders

Browse files

Key fixes:
- Use gr.Tab instead of gr.TabItem for reliable tab switching
- Move all function definitions outside Blocks context
- Remove problematic hidden row pattern for search results
- Move pandas import to module level
- Simplify UI structure to avoid nesting issues
- Rename functions to avoid naming conflicts with components

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +297 -529
app.py CHANGED
@@ -16,6 +16,7 @@ Fetches model metadata from HF Hub and calculates:
16
  import gradio as gr
17
  from huggingface_hub import HfApi, hf_hub_download, list_models
18
  import json
 
19
  from functools import lru_cache
20
  from datetime import datetime
21
 
@@ -142,19 +143,13 @@ def estimate_params_from_safetensors(info) -> tuple[int, str]:
142
 
143
  def get_head_dim(config: dict) -> int:
144
  """Calculate head dimension from config, with fallbacks."""
145
- # Try to get it directly
146
  if "head_dim" in config:
147
  return config["head_dim"]
148
-
149
- # Calculate from hidden_size and num_attention_heads
150
  hidden_size = config.get("hidden_size", config.get("n_embd", 0))
151
  num_heads = config.get("num_attention_heads", config.get("n_head", 0))
152
-
153
  if hidden_size and num_heads:
154
  return hidden_size // num_heads
155
-
156
- # Common defaults by model family
157
- return 128 # Most common default
158
 
159
 
160
  def estimate_kv_cache_size(
@@ -165,44 +160,25 @@ def estimate_kv_cache_size(
165
  batch_size: int = 1,
166
  dtype_bytes: int = 2
167
  ) -> int:
168
- """
169
- KV cache size = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
170
-
171
- The 2 accounts for both K and V caches.
172
- """
173
- kv_cache_bytes = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
174
- return kv_cache_bytes
175
 
176
 
177
- def estimate_training_memory(
178
- param_count: int,
179
- dtype_bytes: int,
180
- optimizer: str = "AdamW"
181
- ) -> dict:
182
- """
183
- Estimate training memory requirements.
184
-
185
- For training, we need:
186
- - Model weights
187
- - Gradients (same size as weights)
188
- - Optimizer states (varies by optimizer)
189
- - Activations (highly variable, estimated)
190
- """
191
  weights_bytes = param_count * dtype_bytes
192
  gradients_bytes = param_count * dtype_bytes
193
 
194
- # Optimizer states
195
  if optimizer == "AdamW":
196
- # AdamW stores: m (momentum), v (variance) in FP32
197
- optimizer_bytes = param_count * 4 * 2 # 2 states, 4 bytes each
198
  elif optimizer == "SGD":
199
- optimizer_bytes = 0 # No extra state (momentum optional)
200
  elif optimizer == "SGD + Momentum":
201
- optimizer_bytes = param_count * 4 # Momentum buffer
202
  elif optimizer == "8-bit Adam":
203
- optimizer_bytes = param_count * 1 * 2 # 2 states, 1 byte each
204
  else:
205
- optimizer_bytes = param_count * 4 * 2 # Default to AdamW
206
 
207
  return {
208
  "weights": weights_bytes,
@@ -215,25 +191,22 @@ def estimate_training_memory(
215
  def calculate_multi_gpu_split(total_vram_gb: float, num_gpus: int, parallelism: str) -> dict:
216
  """Calculate memory distribution across multiple GPUs."""
217
  if parallelism == "Tensor Parallelism":
218
- # Weights and KV cache split evenly
219
  per_gpu = total_vram_gb / num_gpus
220
- overhead = 0.05 * total_vram_gb # Communication overhead
221
  return {
222
  "per_gpu": per_gpu + (overhead / num_gpus),
223
  "total": total_vram_gb + overhead,
224
  "efficiency": "High (best for inference)",
225
  }
226
  elif parallelism == "Pipeline Parallelism":
227
- # Layers distributed, but activation memory at boundaries
228
  per_gpu = total_vram_gb / num_gpus
229
- overhead = 0.1 * total_vram_gb # Activation memory overhead
230
  return {
231
  "per_gpu": per_gpu + (overhead / num_gpus),
232
  "total": total_vram_gb + overhead,
233
  "efficiency": "Medium (good for training)",
234
  }
235
- else: # Data Parallelism
236
- # Full model on each GPU
237
  return {
238
  "per_gpu": total_vram_gb,
239
  "total": total_vram_gb * num_gpus,
@@ -249,36 +222,18 @@ def estimate_lora_memory(
249
  target_modules: int = 4,
250
  use_qlora: bool = False
251
  ) -> dict:
252
- """
253
- Estimate LoRA/QLoRA fine-tuning memory requirements.
254
-
255
- LoRA adds low-rank adaptation matrices to specific layers.
256
- QLoRA additionally quantizes the base model to 4-bit.
257
- """
258
- # Base model weights
259
  if use_qlora:
260
- # QLoRA: 4-bit quantized weights
261
- base_weights_bytes = param_count * 0.5 # 4-bit = 0.5 bytes/param
262
  else:
263
  base_weights_bytes = param_count * dtype_bytes
264
 
265
- # LoRA adapter parameters (A and B matrices for each target module)
266
- # Typical target modules: q_proj, k_proj, v_proj, o_proj (4 modules)
267
- # Each LoRA layer: hidden_size * rank (A) + rank * hidden_size (B)
268
- # Approximate as 2 * hidden_size * rank per module
269
- # For simplicity, estimate based on total params
270
- lora_params_ratio = (lora_rank * 2 * target_modules) / 1000 # Rough estimate
271
- lora_params = int(param_count * lora_params_ratio * 0.01) # Usually ~0.1-1% of base
272
  lora_weights_bytes = lora_params * dtype_bytes
273
-
274
- # Gradients only for LoRA params (not frozen base)
275
  gradients_bytes = lora_params * dtype_bytes
276
-
277
- # Optimizer states for LoRA params only
278
- optimizer_bytes = lora_params * 4 * 2 # AdamW: 2 states, 4 bytes each
279
-
280
- # Activations (still needed, but can use gradient checkpointing)
281
- activation_bytes = base_weights_bytes * 0.5 # Reduced with checkpointing
282
 
283
  return {
284
  "base_weights": base_weights_bytes,
@@ -288,7 +243,7 @@ def estimate_lora_memory(
288
  "optimizer": optimizer_bytes,
289
  "activations": activation_bytes,
290
  "total": base_weights_bytes + lora_weights_bytes + gradients_bytes + optimizer_bytes + activation_bytes,
291
- "vs_full_finetune_ratio": 0.3 if use_qlora else 0.5, # Rough memory savings
292
  }
293
 
294
 
@@ -299,31 +254,18 @@ def estimate_throughput(
299
  context_length: int = 4096,
300
  is_prefill: bool = False
301
  ) -> dict:
302
- """
303
- Estimate tokens per second throughput.
304
-
305
- Based on roofline model: throughput limited by compute or memory bandwidth.
306
- Most LLM inference is memory-bound for single-batch decode.
307
- """
308
- # Rough estimate: 2 FLOPs per parameter per token (forward pass)
309
  flops_per_token = 2 * param_count
310
-
311
- # Peak theoretical throughput (compute-bound)
312
  peak_tokens_per_sec = (gpu_tflops * 1e12) / flops_per_token
313
-
314
- # Memory-bound estimate (more realistic for decode)
315
- # Assume ~1TB/s memory bandwidth for modern GPUs
316
- memory_bandwidth_tbs = 1.0 # TB/s, rough average
317
- bytes_per_token = param_count * 2 # FP16 weights need to be read
318
  memory_bound_tokens = (memory_bandwidth_tbs * 1e12) / bytes_per_token
319
 
320
- # Prefill is more compute-bound, decode is memory-bound
321
  if is_prefill:
322
  effective_tokens = min(peak_tokens_per_sec, memory_bound_tokens * 10) * batch_size
323
  else:
324
  effective_tokens = min(peak_tokens_per_sec, memory_bound_tokens) * batch_size
325
 
326
- # Apply realistic efficiency factor (typically 30-60% of theoretical)
327
  efficiency = 0.4
328
  realistic_tokens = effective_tokens * efficiency
329
 
@@ -336,14 +278,9 @@ def estimate_throughput(
336
  }
337
 
338
 
339
- def calculate_cost_estimate(
340
- vram_required: float,
341
- hours_per_day: float = 8,
342
- days_per_month: float = 22
343
- ) -> list:
344
  """Calculate cost estimates for cloud GPUs that fit the model."""
345
  estimates = []
346
-
347
  for gpu_name, (vram, instance, category, hourly_cost, tflops) in GPU_SPECS.items():
348
  if vram >= vram_required and hourly_cost > 0:
349
  daily_cost = hourly_cost * hours_per_day
@@ -356,21 +293,19 @@ def calculate_cost_estimate(
356
  "monthly": monthly_cost,
357
  "instance": instance,
358
  })
359
-
360
  return sorted(estimates, key=lambda x: x["hourly"])
361
 
362
 
363
- def search_models(query: str, limit: int = 10) -> list:
364
  """Search HuggingFace models by name."""
365
  if not query or len(query) < 2:
366
  return []
367
-
368
  try:
369
  models = list(list_models(
370
  search=query,
371
  sort="downloads",
372
  direction=-1,
373
- limit=limit,
374
  filter="text-generation"
375
  ))
376
  return [m.id for m in models]
@@ -378,27 +313,10 @@ def search_models(query: str, limit: int = 10) -> list:
378
  return []
379
 
380
 
381
- def calculate_flash_attention_savings(
382
- kv_cache_bytes: int,
383
- context_length: int
384
- ) -> dict:
385
- """
386
- Estimate memory savings from Flash Attention.
387
-
388
- Flash Attention uses tiling to reduce memory from O(n^2) to O(n).
389
- """
390
- # Standard attention materializes full attention matrix
391
- # Flash Attention streams through, never materializing full matrix
392
- # Savings primarily in activation memory, not KV cache
393
-
394
- # KV cache itself is O(n), so Flash Attention doesn't reduce it
395
- # But it dramatically reduces peak memory during computation
396
-
397
- # Estimate: Flash Attention reduces peak memory by avoiding
398
- # the O(n^2) attention matrix materialization
399
- standard_attention_overhead = context_length * context_length * 2 # FP16
400
- flash_attention_overhead = context_length * 128 * 2 # Block size overhead
401
-
402
  savings_bytes = standard_attention_overhead - flash_attention_overhead
403
  savings_ratio = 1 - (flash_attention_overhead / max(standard_attention_overhead, 1))
404
 
@@ -423,29 +341,24 @@ def calculate_vram(
423
  lora_rank: int = 16,
424
  show_throughput: bool = True,
425
  show_cost: bool = True
426
- ) -> tuple[str, dict | None]:
427
- """Main calculation function. Returns (markdown_results, chart_data)."""
428
-
429
- # Validate inputs
430
  model_id = model_id.strip()
431
  if not model_id:
432
  raise gr.Error("Please enter a model ID")
433
-
434
  if "/" not in model_id:
435
  raise gr.Error("Model ID should be in format 'organization/model-name'")
436
 
437
- # Fetch model info
438
  info = get_model_info(model_id)
439
  config = get_config(model_id)
440
 
441
  results = []
442
  results.append(f"## Model: [{model_id}](https://huggingface.co/{model_id})\n")
443
 
444
- # Get parameter count and dtype
445
  param_count, dominant_dtype = estimate_params_from_safetensors(info)
446
 
447
  if param_count == 0:
448
- results.append("⚠️ Could not determine parameter count from safetensors metadata.\n")
449
  results.append("Model may use pytorch_model.bin or other format.\n")
450
  return "\n".join(results), None
451
 
@@ -456,13 +369,11 @@ def calculate_vram(
456
  results.append(f"**Dominant dtype:** {dominant_dtype} ({dtype_bytes} bytes/param)")
457
  results.append(f"**Mode:** {mode}")
458
 
459
- # Model weights VRAM
460
  weights_bytes = param_count * dtype_bytes
461
  weights_gb = bytes_to_gb(weights_bytes)
462
- results.append(f"\n### 📦 Weight Memory")
463
  results.append(f"Model weights: **{weights_gb:.2f} GB**")
464
 
465
- # Architecture details
466
  num_layers = config.get("num_hidden_layers", config.get("n_layer", 0))
467
  hidden_size = config.get("hidden_size", config.get("n_embd", 0))
468
  num_attention_heads = config.get("num_attention_heads", config.get("n_head", 0))
@@ -470,72 +381,61 @@ def calculate_vram(
470
  head_dim = get_head_dim(config)
471
  max_position = config.get("max_position_embeddings", config.get("n_positions", "N/A"))
472
 
473
- results.append(f"\n### 🏗️ Architecture (from config.json)")
474
  if "_error" in config:
475
- results.append(f"⚠️ Could not fetch config.json (model may be gated)")
476
  kv_gb = 0
477
  elif num_layers and hidden_size:
478
- results.append(f"- **Layers:** {num_layers}")
479
- results.append(f"- **Hidden size:** {hidden_size}")
480
- results.append(f"- **Attention heads:** {num_attention_heads}")
481
- results.append(f"- **KV heads:** {num_kv_heads} {'(GQA)' if num_kv_heads != num_attention_heads else '(MHA)'}")
482
- results.append(f"- **Head dimension:** {head_dim}")
483
- results.append(f"- **Max context:** {max_position:,}" if isinstance(max_position, int) else f"- **Max context:** {max_position}")
484
-
485
- # KV Cache calculation
486
- results.append(f"\n### 💾 KV Cache (batch_size={batch_size})")
 
 
487
  results.append("| Context | KV Cache | + Weights | Status |")
488
  results.append("|---------|----------|-----------|--------|")
489
 
490
- # Show relevant context lengths
491
  context_points = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
492
  for ctx_len in context_points:
493
  if ctx_len > context_length * 2 and ctx_len > 8192:
494
  break
495
- kv_bytes = estimate_kv_cache_size(
496
- num_layers, num_kv_heads, head_dim, ctx_len, batch_size, dtype_bytes
497
- )
498
  kv_gb_temp = bytes_to_gb(kv_bytes)
499
  total_temp = weights_gb + kv_gb_temp
500
- marker = " ** selected**" if ctx_len == context_length else ""
501
  results.append(f"| {ctx_len:,} | {kv_gb_temp:.2f} GB | {total_temp:.2f} GB |{marker} |")
502
 
503
- # Calculate for selected context
504
- kv_bytes = estimate_kv_cache_size(
505
- num_layers, num_kv_heads, head_dim, context_length, batch_size, dtype_bytes
506
- )
507
  kv_gb = bytes_to_gb(kv_bytes)
508
  else:
509
  results.append("Could not find architecture details")
510
  kv_gb = 0
511
 
512
- # Flash Attention savings
513
  flash_savings = None
514
  if use_flash_attention and kv_gb > 0:
515
- kv_bytes = estimate_kv_cache_size(
516
- num_layers, num_kv_heads, head_dim, context_length, batch_size, dtype_bytes
517
- )
518
  flash_savings = calculate_flash_attention_savings(kv_bytes, context_length)
519
 
520
- # Calculate total based on mode
521
  if mode == "Training (Full)":
522
  training_mem = estimate_training_memory(param_count, dtype_bytes, optimizer)
523
  base_gb = bytes_to_gb(training_mem["total_base"])
524
-
525
- # Activations estimation (rough: ~2x weights for typical batch)
526
  activation_gb = weights_gb * 2 * batch_size
527
  if use_flash_attention and flash_savings:
528
  activation_gb -= flash_savings["savings_gb"]
529
  activation_gb = max(0.1, activation_gb)
530
-
531
  total_gb = base_gb + kv_gb + activation_gb
532
 
533
- results.append(f"\n### 🎓 Training Memory Breakdown")
534
- results.append(f"- **Weights:** {weights_gb:.2f} GB")
535
- results.append(f"- **Gradients:** {bytes_to_gb(training_mem['gradients']):.2f} GB")
536
- results.append(f"- **Optimizer ({optimizer}):** {bytes_to_gb(training_mem['optimizer']):.2f} GB")
537
- results.append(f"- **KV Cache:** {kv_gb:.2f} GB")
538
- results.append(f"- **Activations (est.):** {activation_gb:.2f} GB")
539
 
540
  chart_data = {
541
  "Weights": weights_gb,
@@ -549,12 +449,12 @@ def calculate_vram(
549
  lora_mem = estimate_lora_memory(param_count, dtype_bytes, lora_rank, use_qlora=False)
550
  total_gb = bytes_to_gb(lora_mem["total"])
551
 
552
- results.append(f"\n### 🔧 LoRA Fine-tuning (rank={lora_rank})")
553
- results.append(f"- **Base weights (frozen):** {bytes_to_gb(lora_mem['base_weights']):.2f} GB")
554
- results.append(f"- **LoRA adapters:** {bytes_to_gb(lora_mem['lora_weights']):.3f} GB ({lora_mem['lora_params']:,} params)")
555
- results.append(f"- **Gradients (LoRA only):** {bytes_to_gb(lora_mem['gradients']):.3f} GB")
556
- results.append(f"- **Optimizer states:** {bytes_to_gb(lora_mem['optimizer']):.3f} GB")
557
- results.append(f"- **Activations:** {bytes_to_gb(lora_mem['activations']):.2f} GB")
558
  results.append(f"\n*Saves ~{(1-lora_mem['vs_full_finetune_ratio'])*100:.0f}% vs full fine-tuning*")
559
 
560
  chart_data = {
@@ -569,12 +469,12 @@ def calculate_vram(
569
  lora_mem = estimate_lora_memory(param_count, dtype_bytes, lora_rank, use_qlora=True)
570
  total_gb = bytes_to_gb(lora_mem["total"])
571
 
572
- results.append(f"\n### 🔧 QLoRA Fine-tuning (4-bit base, rank={lora_rank})")
573
- results.append(f"- **Base weights (4-bit):** {bytes_to_gb(lora_mem['base_weights']):.2f} GB")
574
- results.append(f"- **LoRA adapters:** {bytes_to_gb(lora_mem['lora_weights']):.3f} GB ({lora_mem['lora_params']:,} params)")
575
- results.append(f"- **Gradients (LoRA only):** {bytes_to_gb(lora_mem['gradients']):.3f} GB")
576
- results.append(f"- **Optimizer states:** {bytes_to_gb(lora_mem['optimizer']):.3f} GB")
577
- results.append(f"- **Activations:** {bytes_to_gb(lora_mem['activations']):.2f} GB")
578
  results.append(f"\n*Saves ~{(1-lora_mem['vs_full_finetune_ratio'])*100:.0f}% vs full fine-tuning*")
579
 
580
  chart_data = {
@@ -586,22 +486,18 @@ def calculate_vram(
586
  }
587
 
588
  else:
589
- # Inference mode
590
  framework_overhead = SERVING_FRAMEWORKS.get(serving_framework, 1.15)
591
  base_total = weights_gb + kv_gb
592
  overhead_gb = base_total * (framework_overhead - 1)
593
-
594
- # Flash Attention reduces activation memory overhead during inference
595
  if use_flash_attention and flash_savings:
596
  overhead_gb -= min(flash_savings["savings_gb"] * 0.1, overhead_gb * 0.5)
597
  overhead_gb = max(0, overhead_gb)
598
-
599
  total_gb = base_total + overhead_gb
600
 
601
- results.append(f"\n### Inference Memory ({serving_framework})")
602
- results.append(f"- **Weights:** {weights_gb:.2f} GB")
603
- results.append(f"- **KV Cache:** {kv_gb:.2f} GB")
604
- results.append(f"- **Framework overhead:** {overhead_gb:.2f} GB ({(framework_overhead-1)*100:.0f}%)")
605
 
606
  chart_data = {
607
  "Weights": weights_gb,
@@ -609,100 +505,76 @@ def calculate_vram(
609
  "Overhead": overhead_gb,
610
  }
611
 
612
- # Flash Attention info
613
  if use_flash_attention and flash_savings and flash_savings["savings_gb"] > 0.01:
614
- results.append(f"\n### Flash Attention")
615
- results.append(f"- **Enabled:** Yes")
616
- results.append(f"- **Peak memory savings:** ~{flash_savings['savings_gb']:.2f} GB ({flash_savings['savings_percent']:.1f}%)")
617
 
618
- results.append(f"\n### 📊 Total VRAM Required: **{total_gb:.2f} GB**")
619
 
620
- # Multi-GPU calculations
621
  if num_gpus > 1:
622
  multi_gpu = calculate_multi_gpu_split(total_gb, num_gpus, parallelism)
623
- results.append(f"\n### 🔗 Multi-GPU ({num_gpus}x GPUs, {parallelism})")
624
- results.append(f"- **Per GPU:** {multi_gpu['per_gpu']:.2f} GB")
625
- results.append(f"- **Total across GPUs:** {multi_gpu['total']:.2f} GB")
626
- results.append(f"- **Efficiency:** {multi_gpu['efficiency']}")
627
-
628
- # Update total for GPU recommendations
629
  effective_vram_needed = multi_gpu['per_gpu']
630
  else:
631
  effective_vram_needed = total_gb
632
 
633
- # GPU Recommendations
634
- results.append(f"\n### 🎮 GPU Recommendations")
635
  results.append("| GPU | VRAM | Fits? | Headroom | Est. tok/s | Instance |")
636
  results.append("|-----|------|-------|----------|------------|----------|")
637
 
638
  for gpu_name, (vram, instance, category, hourly_cost, tflops) in GPU_SPECS.items():
639
- fits = "" if vram >= effective_vram_needed else ""
640
  headroom = vram - effective_vram_needed
641
  headroom_str = f"+{headroom:.1f} GB" if headroom > 0 else f"{headroom:.1f} GB"
642
-
643
- # Estimate throughput for this GPU
644
  if show_throughput and vram >= effective_vram_needed:
645
  throughput = estimate_throughput(param_count, tflops, batch_size, context_length)
646
  tok_str = f"~{throughput['estimated_tokens_per_sec']:.0f}"
647
  else:
648
  tok_str = "-"
649
-
650
  results.append(f"| {gpu_name} | {vram} GB | {fits} | {headroom_str} | {tok_str} | {instance} |")
651
 
652
- # Quantization options (if model doesn't fit on consumer GPUs)
653
  if effective_vram_needed > 24:
654
- results.append(f"\n### 🗜️ Quantization Options")
655
- results.append("To fit on consumer GPUs (24 GB), consider these options:\n")
656
  results.append("| Method | Est. Size | Quality | Notes |")
657
  results.append("|--------|-----------|---------|-------|")
658
-
659
  for method, specs in QUANTIZATION_METHODS.items():
660
  quant_size = bytes_to_gb(param_count * specs["bytes_per_param"])
661
- quant_with_overhead = quant_size * 1.1 # Small overhead
662
- fits = "" if quant_with_overhead <= 24 else ""
663
- results.append(f"| {method} | {quant_with_overhead:.1f} GB | {specs['quality']} | {fits} {specs['desc']} |")
664
-
665
- results.append(f"\n**Tip:** Search for `{model_id.split('/')[-1]} GGUF` or `{model_id.split('/')[-1]} AWQ` on HuggingFace.")
666
 
667
- # Cost estimates for cloud GPUs
668
  if show_cost:
669
  cost_estimates = calculate_cost_estimate(effective_vram_needed)
670
  if cost_estimates:
671
- results.append(f"\n### 💰 Cloud Cost Estimates")
672
  results.append("*Based on 8 hrs/day, 22 days/month*\n")
673
  results.append("| GPU | Hourly | Daily | Monthly |")
674
  results.append("|-----|--------|-------|---------|")
675
- for est in cost_estimates[:5]: # Top 5 cheapest
676
  results.append(f"| {est['gpu']} | ${est['hourly']:.2f} | ${est['daily']:.2f} | ${est['monthly']:.0f} |")
677
 
678
- return "\n".join(results), chart_data
679
-
 
 
 
680
 
681
- def create_memory_chart(chart_data: dict | None):
682
- """Create a bar chart for memory breakdown."""
683
- if not chart_data:
684
- return None
685
 
686
- labels = list(chart_data.keys())
687
- values = list(chart_data.values())
688
 
689
- return gr.BarPlot(
690
- value={"Component": labels, "GB": values},
691
- x="Component",
692
- y="GB",
693
- title="Memory Breakdown",
694
- height=300,
695
- width=400,
696
- )
697
-
698
-
699
- def compare_models(model_ids_text: str, context_length: int = 4096) -> str:
700
  """Compare multiple models side by side."""
701
  model_ids = [m.strip() for m in model_ids_text.split("\n") if m.strip()]
702
 
703
  if len(model_ids) < 2:
704
  return "Please enter at least 2 model IDs (one per line)"
705
-
706
  if len(model_ids) > 5:
707
  return "Maximum 5 models for comparison"
708
 
@@ -716,32 +588,23 @@ def compare_models(model_ids_text: str, context_length: int = 4096) -> str:
716
  param_count, dominant_dtype = estimate_params_from_safetensors(info)
717
 
718
  if param_count == 0:
719
- comparison_data.append({
720
- "model": model_id,
721
- "params": "N/A",
722
- "error": "Could not determine parameters"
723
- })
724
  continue
725
 
726
  dtype_bytes = DTYPE_BYTES.get(dominant_dtype, 2)
727
  weights_gb = bytes_to_gb(param_count * dtype_bytes)
728
 
729
  num_layers = config.get("num_hidden_layers", config.get("n_layer", 0))
730
- num_kv_heads = config.get("num_key_value_heads",
731
- config.get("num_attention_heads", 0))
732
  head_dim = get_head_dim(config)
733
 
734
- kv_bytes = estimate_kv_cache_size(
735
- num_layers, num_kv_heads, head_dim, context_length, 1, dtype_bytes
736
- )
737
  kv_gb = bytes_to_gb(kv_bytes)
738
  total_inference = weights_gb + kv_gb
739
 
740
- # Training estimate
741
  training_mem = estimate_training_memory(param_count, dtype_bytes)
742
  training_gb = bytes_to_gb(training_mem["total_base"]) + weights_gb * 2
743
 
744
- # QLoRA estimate
745
  qlora_mem = estimate_lora_memory(param_count, dtype_bytes, 16, use_qlora=True)
746
  qlora_gb = bytes_to_gb(qlora_mem["total"])
747
 
@@ -749,20 +612,13 @@ def compare_models(model_ids_text: str, context_length: int = 4096) -> str:
749
  "model": model_id.split("/")[-1],
750
  "full_id": model_id,
751
  "params": f"{param_count/1e9:.1f}B",
752
- "dtype": dominant_dtype,
753
- "weights_gb": weights_gb,
754
- "kv_gb": kv_gb,
755
  "inference_gb": total_inference,
756
  "training_gb": training_gb,
757
  "qlora_gb": qlora_gb,
758
  })
759
  except Exception as e:
760
- comparison_data.append({
761
- "model": model_id,
762
- "error": str(e)
763
- })
764
 
765
- # Build comparison table
766
  results.append(f"*Context length: {context_length:,}*\n")
767
  results.append("| Model | Params | Inference | Training | QLoRA |")
768
  results.append("|-------|--------|-----------|----------|-------|")
@@ -773,21 +629,16 @@ def compare_models(model_ids_text: str, context_length: int = 4096) -> str:
773
  else:
774
  results.append(
775
  f"| [{data['model']}](https://huggingface.co/{data['full_id']}) | "
776
- f"{data['params']} | "
777
- f"{data['inference_gb']:.1f} GB | "
778
- f"{data['training_gb']:.1f} GB | "
779
- f"{data['qlora_gb']:.1f} GB |"
780
  )
781
 
782
- # Find minimum for each category
783
  valid_data = [d for d in comparison_data if "error" not in d]
784
  if len(valid_data) >= 2:
785
  results.append("\n### Recommendations")
786
-
787
  min_inference = min(valid_data, key=lambda x: x["inference_gb"])
788
  min_training = min(valid_data, key=lambda x: x["training_gb"])
789
  min_qlora = min(valid_data, key=lambda x: x["qlora_gb"])
790
-
791
  results.append(f"- **Best for inference:** {min_inference['model']} ({min_inference['inference_gb']:.1f} GB)")
792
  results.append(f"- **Best for training:** {min_training['model']} ({min_training['training_gb']:.1f} GB)")
793
  results.append(f"- **Best for QLoRA:** {min_qlora['model']} ({min_qlora['qlora_gb']:.1f} GB)")
@@ -795,7 +646,7 @@ def compare_models(model_ids_text: str, context_length: int = 4096) -> str:
795
  return "\n".join(results)
796
 
797
 
798
- def export_results(result_text: str, format_type: str) -> str:
799
  """Export results to different formats."""
800
  if not result_text:
801
  return "No results to export. Run a calculation first."
@@ -803,15 +654,8 @@ def export_results(result_text: str, format_type: str) -> str:
803
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
804
 
805
  if format_type == "JSON":
806
- # Parse markdown to create structured JSON
807
- import re
808
  lines = result_text.split("\n")
809
- data = {
810
- "timestamp": timestamp,
811
- "raw_markdown": result_text,
812
- "sections": {}
813
- }
814
-
815
  current_section = "header"
816
  for line in lines:
817
  if line.startswith("### "):
@@ -821,298 +665,222 @@ def export_results(result_text: str, format_type: str) -> str:
821
  if current_section not in data["sections"]:
822
  data["sections"][current_section] = []
823
  data["sections"][current_section].append(line.strip())
824
-
825
  return json.dumps(data, indent=2)
 
 
 
826
 
827
- else: # Plain text
828
- # Convert markdown to plain text
829
- plain = result_text
830
- plain = plain.replace("**", "")
831
- plain = plain.replace("###", "\n===")
832
- plain = plain.replace("##", "\n===")
833
- plain = f"VRAM Calculator Export - {timestamp}\n{'='*50}\n\n{plain}"
834
- return plain
835
 
 
 
 
 
 
 
 
 
836
 
837
- # Build Gradio interface
838
- with gr.Blocks(title="VRAM Calculator", theme=gr.themes.Soft()) as demo:
839
- gr.Markdown("""
840
- # VRAM & Instance Type Calculator
841
 
842
- Estimate GPU memory requirements for HuggingFace models. Supports inference, training, LoRA/QLoRA fine-tuning,
843
- multi-GPU setups, model comparison, and detailed quantization recommendations.
844
- """)
 
 
845
 
846
- with gr.Tabs():
847
- # === CALCULATOR TAB ===
848
- with gr.TabItem("Calculator"):
849
- with gr.Row():
850
- with gr.Column(scale=2):
851
- model_input = gr.Textbox(
852
- label="Model ID",
853
- placeholder="meta-llama/Llama-3.1-8B",
854
- info="Full HuggingFace model ID (org/model-name)"
855
- )
856
- with gr.Column(scale=1):
857
- search_input = gr.Textbox(
858
- label="Search Models",
859
- placeholder="llama 8b",
860
- info="Search HuggingFace for models"
861
- )
862
- search_btn = gr.Button("Search", size="sm")
863
-
864
- with gr.Row(visible=False) as search_results_row:
865
- search_results = gr.Dropdown(
866
- label="Search Results (click to select)",
867
- choices=[],
868
- interactive=True,
869
- )
870
 
871
- def do_search(query):
872
- if not query:
873
- return gr.update(visible=False), gr.update(choices=[])
874
- results = search_models(query, limit=10)
875
- if results:
876
- return gr.update(visible=True), gr.update(choices=results, value=results[0])
877
- return gr.update(visible=True), gr.update(choices=["No models found"], value=None)
878
-
879
- def select_model(selected):
880
- if selected and selected != "No models found":
881
- return selected
882
- return ""
883
-
884
- search_btn.click(
885
- fn=do_search,
886
- inputs=[search_input],
887
- outputs=[search_results_row, search_results]
888
  )
889
- search_results.change(
890
- fn=select_model,
891
- inputs=[search_results],
892
- outputs=[model_input]
 
893
  )
894
 
895
- with gr.Row():
896
- with gr.Column(scale=1):
897
- mode_input = gr.Radio(
898
- choices=["Inference", "Training (Full)", "LoRA Fine-tuning", "QLoRA Fine-tuning"],
899
- value="Inference",
900
- label="Mode",
901
- info="LoRA/QLoRA use significantly less memory"
902
- )
903
- with gr.Column(scale=1):
904
- context_input = gr.Slider(
905
- label="Context Length",
906
- minimum=512,
907
- maximum=131072,
908
- value=4096,
909
- step=512,
910
- info="Sequence length for KV cache"
911
- )
912
- with gr.Column(scale=1):
913
- batch_input = gr.Slider(
914
- label="Batch Size",
915
- minimum=1,
916
- maximum=64,
917
- value=1,
918
- step=1,
919
- info="Concurrent sequences"
920
- )
921
-
922
- with gr.Accordion("Advanced Options", open=False):
923
- with gr.Row():
924
- with gr.Column():
925
- serving_input = gr.Dropdown(
926
- choices=list(SERVING_FRAMEWORKS.keys()),
927
- value="None (raw PyTorch)",
928
- label="Serving Framework",
929
- info="Different frameworks have different overhead"
930
- )
931
- optimizer_input = gr.Dropdown(
932
- choices=["AdamW", "SGD", "SGD + Momentum", "8-bit Adam"],
933
- value="AdamW",
934
- label="Optimizer (Training mode)",
935
- info="Optimizer state memory varies"
936
- )
937
- lora_rank_input = gr.Slider(
938
- label="LoRA Rank",
939
- minimum=4,
940
- maximum=128,
941
- value=16,
942
- step=4,
943
- info="Higher rank = more capacity but more memory"
944
- )
945
- with gr.Column():
946
- num_gpus_input = gr.Slider(
947
- label="Number of GPUs",
948
- minimum=1,
949
- maximum=8,
950
- value=1,
951
- step=1,
952
- info="For multi-GPU setups"
953
- )
954
- parallelism_input = gr.Dropdown(
955
- choices=["Tensor Parallelism", "Pipeline Parallelism", "Data Parallelism"],
956
- value="Tensor Parallelism",
957
- label="Parallelism Strategy",
958
- info="How to distribute across GPUs"
959
- )
960
- flash_attention_input = gr.Checkbox(
961
- label="Use Flash Attention",
962
- value=True,
963
- info="Reduces peak memory usage"
964
- )
965
- with gr.Row():
966
- show_throughput_input = gr.Checkbox(
967
- label="Show Throughput Estimates",
968
- value=True,
969
- info="Estimated tokens/sec per GPU"
970
- )
971
- show_cost_input = gr.Checkbox(
972
- label="Show Cost Estimates",
973
- value=True,
974
- info="Cloud GPU hourly/monthly costs"
975
- )
976
-
977
- calculate_btn = gr.Button("Calculate VRAM", variant="primary", size="lg")
978
-
979
- with gr.Row():
980
- with gr.Column(scale=3):
981
- output = gr.Markdown(label="Results")
982
- with gr.Column(scale=1):
983
- chart_output = gr.BarPlot(
984
- x="Component",
985
- y="GB",
986
- title="Memory Breakdown",
987
- height=350,
988
- )
989
-
990
- def run_calculation(
991
- model_id, context_length, batch_size, mode, optimizer, serving,
992
- num_gpus, parallelism, flash_attention, lora_rank, show_throughput, show_cost
993
- ):
994
- result_text, chart_data = calculate_vram(
995
- model_id, context_length, batch_size, mode, optimizer, serving,
996
- num_gpus, parallelism, flash_attention, lora_rank, show_throughput, show_cost
997
- )
998
- if chart_data:
999
- import pandas as pd
1000
- df = pd.DataFrame({
1001
- "Component": list(chart_data.keys()),
1002
- "GB": list(chart_data.values())
1003
- })
1004
- return result_text, df
1005
- return result_text, None
1006
-
1007
- calculate_btn.click(
1008
- fn=run_calculation,
1009
- inputs=[
1010
- model_input, context_input, batch_input, mode_input,
1011
- optimizer_input, serving_input, num_gpus_input, parallelism_input,
1012
- flash_attention_input, lora_rank_input, show_throughput_input, show_cost_input
1013
- ],
1014
- outputs=[output, chart_output]
1015
  )
1016
 
1017
- # Examples
1018
- gr.Examples(
1019
- examples=[
1020
- ["meta-llama/Llama-3.1-8B", 4096, 1],
1021
- ["meta-llama/Llama-3.1-70B", 8192, 1],
1022
- ["mistralai/Mistral-7B-v0.1", 8192, 1],
1023
- ["Qwen/Qwen2.5-72B", 32768, 1],
1024
- ["google/gemma-2-27b", 8192, 1],
1025
- ["microsoft/phi-4", 16384, 1],
1026
- ["deepseek-ai/DeepSeek-V3", 4096, 1],
1027
- ["meta-llama/Llama-3.3-70B-Instruct", 8192, 1],
1028
- ],
1029
- inputs=[model_input, context_input, batch_input],
1030
- label="Popular Models"
1031
- )
1032
 
1033
- # === COMPARE TAB ===
1034
- with gr.TabItem("Compare Models"):
1035
- gr.Markdown("""
1036
- Compare VRAM requirements across multiple models side-by-side.
1037
- Enter model IDs one per line (2-5 models).
1038
- """)
1039
-
1040
- compare_models_input = gr.Textbox(
1041
- label="Model IDs (one per line)",
1042
- placeholder="meta-llama/Llama-3.1-8B\nmistralai/Mistral-7B-v0.1\nQwen/Qwen2.5-7B",
1043
- lines=5,
1044
  )
1045
- compare_context_input = gr.Slider(
1046
  label="Context Length",
1047
  minimum=512,
1048
  maximum=131072,
1049
  value=4096,
1050
- step=512,
1051
  )
1052
- compare_btn = gr.Button("Compare Models", variant="primary")
1053
- compare_output = gr.Markdown(label="Comparison Results")
1054
-
1055
- compare_btn.click(
1056
- fn=compare_models,
1057
- inputs=[compare_models_input, compare_context_input],
1058
- outputs=compare_output
1059
  )
1060
 
1061
- gr.Examples(
1062
- examples=[
1063
- ["meta-llama/Llama-3.1-8B\nmistralai/Mistral-7B-v0.1\nQwen/Qwen2.5-7B", 4096],
1064
- ["meta-llama/Llama-3.1-70B\nQwen/Qwen2.5-72B\nmeta-llama/Llama-3.3-70B-Instruct", 8192],
1065
- ],
1066
- inputs=[compare_models_input, compare_context_input],
1067
- label="Example Comparisons"
1068
- )
 
 
 
 
 
 
 
 
 
 
 
1069
 
1070
- # === EXPORT TAB ===
1071
- with gr.TabItem("Export"):
1072
- gr.Markdown("""
1073
- Export your calculation results to JSON or plain text format.
1074
- First run a calculation in the Calculator tab, then copy the results here.
1075
- """)
1076
-
1077
- export_input = gr.Textbox(
1078
- label="Paste Results Here",
1079
- placeholder="Paste the calculation results from the Calculator tab...",
1080
- lines=10,
1081
- )
1082
- export_format = gr.Radio(
1083
- choices=["JSON", "Plain Text"],
1084
- value="JSON",
1085
- label="Export Format"
1086
- )
1087
- export_btn = gr.Button("Export", variant="primary")
1088
- export_output = gr.Textbox(
1089
- label="Exported Data",
1090
- lines=15,
1091
- show_copy_button=True,
1092
- )
1093
 
1094
- export_btn.click(
1095
- fn=export_results,
1096
- inputs=[export_input, export_format],
1097
- outputs=export_output
 
 
 
 
 
 
 
 
 
 
1098
  )
1099
 
1100
- # Notes outside tabs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1101
  gr.Markdown("""
1102
  ---
1103
- ### Notes
1104
- - **Inference mode:** Weights + KV cache + framework overhead
1105
- - **Training modes:** Full training, LoRA, and QLoRA with different memory profiles
1106
- - **KV cache:** Scales linearly with context length and batch size
1107
- - **Multi-GPU:** Tensor parallelism splits memory; data parallelism replicates it
1108
- - **Quantization:** GGUF/AWQ/GPTQ can reduce memory 2-8x with minimal quality loss
1109
-
1110
- ### Disclaimers
1111
- - Estimates are approximate; actual usage varies by implementation
1112
- - Flash Attention and other optimizations can reduce peak memory
1113
- - Throughput estimates assume ideal conditions
1114
-
1115
- Built with Gradio & HuggingFace Hub API
1116
  """)
1117
 
1118
 
 
16
  import gradio as gr
17
  from huggingface_hub import HfApi, hf_hub_download, list_models
18
  import json
19
+ import pandas as pd
20
  from functools import lru_cache
21
  from datetime import datetime
22
 
 
143
 
144
def get_head_dim(config: dict) -> int:
    """Return the per-attention-head dimension for a transformer config.

    Prefers an explicit ``head_dim`` entry when the config provides one,
    otherwise derives hidden_size // num_attention_heads (also checking
    the GPT-2-style ``n_embd`` / ``n_head`` key names), and finally falls
    back to 128, the most common head dimension across model families.
    """
    # Explicit value wins when the config ships one.
    try:
        return config["head_dim"]
    except KeyError:
        pass

    # Derive from model width and head count, covering legacy key names.
    width = config.get("hidden_size", config.get("n_embd", 0))
    heads = config.get("num_attention_heads", config.get("n_head", 0))
    return width // heads if width and heads else 128
 
 
153
 
154
 
155
  def estimate_kv_cache_size(
 
160
  batch_size: int = 1,
161
  dtype_bytes: int = 2
162
  ) -> int:
163
+ """KV cache size = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes"""
164
+ return 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
 
 
 
 
 
165
 
166
 
167
+ def estimate_training_memory(param_count: int, dtype_bytes: int, optimizer: str = "AdamW") -> dict:
168
+ """Estimate training memory requirements."""
 
 
 
 
 
 
 
 
 
 
 
 
169
  weights_bytes = param_count * dtype_bytes
170
  gradients_bytes = param_count * dtype_bytes
171
 
 
172
  if optimizer == "AdamW":
173
+ optimizer_bytes = param_count * 4 * 2
 
174
  elif optimizer == "SGD":
175
+ optimizer_bytes = 0
176
  elif optimizer == "SGD + Momentum":
177
+ optimizer_bytes = param_count * 4
178
  elif optimizer == "8-bit Adam":
179
+ optimizer_bytes = param_count * 1 * 2
180
  else:
181
+ optimizer_bytes = param_count * 4 * 2
182
 
183
  return {
184
  "weights": weights_bytes,
 
191
  def calculate_multi_gpu_split(total_vram_gb: float, num_gpus: int, parallelism: str) -> dict:
192
  """Calculate memory distribution across multiple GPUs."""
193
  if parallelism == "Tensor Parallelism":
 
194
  per_gpu = total_vram_gb / num_gpus
195
+ overhead = 0.05 * total_vram_gb
196
  return {
197
  "per_gpu": per_gpu + (overhead / num_gpus),
198
  "total": total_vram_gb + overhead,
199
  "efficiency": "High (best for inference)",
200
  }
201
  elif parallelism == "Pipeline Parallelism":
 
202
  per_gpu = total_vram_gb / num_gpus
203
+ overhead = 0.1 * total_vram_gb
204
  return {
205
  "per_gpu": per_gpu + (overhead / num_gpus),
206
  "total": total_vram_gb + overhead,
207
  "efficiency": "Medium (good for training)",
208
  }
209
+ else:
 
210
  return {
211
  "per_gpu": total_vram_gb,
212
  "total": total_vram_gb * num_gpus,
 
222
  target_modules: int = 4,
223
  use_qlora: bool = False
224
  ) -> dict:
225
+ """Estimate LoRA/QLoRA fine-tuning memory requirements."""
 
 
 
 
 
 
226
  if use_qlora:
227
+ base_weights_bytes = param_count * 0.5
 
228
  else:
229
  base_weights_bytes = param_count * dtype_bytes
230
 
231
+ lora_params_ratio = (lora_rank * 2 * target_modules) / 1000
232
+ lora_params = int(param_count * lora_params_ratio * 0.01)
 
 
 
 
 
233
  lora_weights_bytes = lora_params * dtype_bytes
 
 
234
  gradients_bytes = lora_params * dtype_bytes
235
+ optimizer_bytes = lora_params * 4 * 2
236
+ activation_bytes = base_weights_bytes * 0.5
 
 
 
 
237
 
238
  return {
239
  "base_weights": base_weights_bytes,
 
243
  "optimizer": optimizer_bytes,
244
  "activations": activation_bytes,
245
  "total": base_weights_bytes + lora_weights_bytes + gradients_bytes + optimizer_bytes + activation_bytes,
246
+ "vs_full_finetune_ratio": 0.3 if use_qlora else 0.5,
247
  }
248
 
249
 
 
254
  context_length: int = 4096,
255
  is_prefill: bool = False
256
  ) -> dict:
257
+ """Estimate tokens per second throughput."""
 
 
 
 
 
 
258
  flops_per_token = 2 * param_count
 
 
259
  peak_tokens_per_sec = (gpu_tflops * 1e12) / flops_per_token
260
+ memory_bandwidth_tbs = 1.0
261
+ bytes_per_token = param_count * 2
 
 
 
262
  memory_bound_tokens = (memory_bandwidth_tbs * 1e12) / bytes_per_token
263
 
 
264
  if is_prefill:
265
  effective_tokens = min(peak_tokens_per_sec, memory_bound_tokens * 10) * batch_size
266
  else:
267
  effective_tokens = min(peak_tokens_per_sec, memory_bound_tokens) * batch_size
268
 
 
269
  efficiency = 0.4
270
  realistic_tokens = effective_tokens * efficiency
271
 
 
278
  }
279
 
280
 
281
+ def calculate_cost_estimate(vram_required: float, hours_per_day: float = 8, days_per_month: float = 22) -> list:
 
 
 
 
282
  """Calculate cost estimates for cloud GPUs that fit the model."""
283
  estimates = []
 
284
  for gpu_name, (vram, instance, category, hourly_cost, tflops) in GPU_SPECS.items():
285
  if vram >= vram_required and hourly_cost > 0:
286
  daily_cost = hourly_cost * hours_per_day
 
293
  "monthly": monthly_cost,
294
  "instance": instance,
295
  })
 
296
  return sorted(estimates, key=lambda x: x["hourly"])
297
 
298
 
299
+ def search_models_fn(query: str) -> list:
300
  """Search HuggingFace models by name."""
301
  if not query or len(query) < 2:
302
  return []
 
303
  try:
304
  models = list(list_models(
305
  search=query,
306
  sort="downloads",
307
  direction=-1,
308
+ limit=10,
309
  filter="text-generation"
310
  ))
311
  return [m.id for m in models]
 
313
  return []
314
 
315
 
316
+ def calculate_flash_attention_savings(kv_cache_bytes: int, context_length: int) -> dict:
317
+ """Estimate memory savings from Flash Attention."""
318
+ standard_attention_overhead = context_length * context_length * 2
319
+ flash_attention_overhead = context_length * 128 * 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  savings_bytes = standard_attention_overhead - flash_attention_overhead
321
  savings_ratio = 1 - (flash_attention_overhead / max(standard_attention_overhead, 1))
322
 
 
341
  lora_rank: int = 16,
342
  show_throughput: bool = True,
343
  show_cost: bool = True
344
+ ):
345
+ """Main calculation function. Returns (markdown_results, chart_dataframe)."""
 
 
346
  model_id = model_id.strip()
347
  if not model_id:
348
  raise gr.Error("Please enter a model ID")
 
349
  if "/" not in model_id:
350
  raise gr.Error("Model ID should be in format 'organization/model-name'")
351
 
 
352
  info = get_model_info(model_id)
353
  config = get_config(model_id)
354
 
355
  results = []
356
  results.append(f"## Model: [{model_id}](https://huggingface.co/{model_id})\n")
357
 
 
358
  param_count, dominant_dtype = estimate_params_from_safetensors(info)
359
 
360
  if param_count == 0:
361
+ results.append("Could not determine parameter count from safetensors metadata.\n")
362
  results.append("Model may use pytorch_model.bin or other format.\n")
363
  return "\n".join(results), None
364
 
 
369
  results.append(f"**Dominant dtype:** {dominant_dtype} ({dtype_bytes} bytes/param)")
370
  results.append(f"**Mode:** {mode}")
371
 
 
372
  weights_bytes = param_count * dtype_bytes
373
  weights_gb = bytes_to_gb(weights_bytes)
374
+ results.append(f"\n### Weight Memory")
375
  results.append(f"Model weights: **{weights_gb:.2f} GB**")
376
 
 
377
  num_layers = config.get("num_hidden_layers", config.get("n_layer", 0))
378
  hidden_size = config.get("hidden_size", config.get("n_embd", 0))
379
  num_attention_heads = config.get("num_attention_heads", config.get("n_head", 0))
 
381
  head_dim = get_head_dim(config)
382
  max_position = config.get("max_position_embeddings", config.get("n_positions", "N/A"))
383
 
384
+ results.append(f"\n### Architecture")
385
  if "_error" in config:
386
+ results.append(f"Could not fetch config.json (model may be gated)")
387
  kv_gb = 0
388
  elif num_layers and hidden_size:
389
+ results.append(f"- Layers: {num_layers}")
390
+ results.append(f"- Hidden size: {hidden_size}")
391
+ results.append(f"- Attention heads: {num_attention_heads}")
392
+ results.append(f"- KV heads: {num_kv_heads} {'(GQA)' if num_kv_heads != num_attention_heads else '(MHA)'}")
393
+ results.append(f"- Head dimension: {head_dim}")
394
+ if isinstance(max_position, int):
395
+ results.append(f"- Max context: {max_position:,}")
396
+ else:
397
+ results.append(f"- Max context: {max_position}")
398
+
399
+ results.append(f"\n### KV Cache (batch_size={batch_size})")
400
  results.append("| Context | KV Cache | + Weights | Status |")
401
  results.append("|---------|----------|-----------|--------|")
402
 
 
403
  context_points = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
404
  for ctx_len in context_points:
405
  if ctx_len > context_length * 2 and ctx_len > 8192:
406
  break
407
+ kv_bytes = estimate_kv_cache_size(num_layers, num_kv_heads, head_dim, ctx_len, batch_size, dtype_bytes)
 
 
408
  kv_gb_temp = bytes_to_gb(kv_bytes)
409
  total_temp = weights_gb + kv_gb_temp
410
+ marker = " **<- selected**" if ctx_len == context_length else ""
411
  results.append(f"| {ctx_len:,} | {kv_gb_temp:.2f} GB | {total_temp:.2f} GB |{marker} |")
412
 
413
+ kv_bytes = estimate_kv_cache_size(num_layers, num_kv_heads, head_dim, context_length, batch_size, dtype_bytes)
 
 
 
414
  kv_gb = bytes_to_gb(kv_bytes)
415
  else:
416
  results.append("Could not find architecture details")
417
  kv_gb = 0
418
 
 
419
  flash_savings = None
420
  if use_flash_attention and kv_gb > 0:
421
+ kv_bytes = estimate_kv_cache_size(num_layers, num_kv_heads, head_dim, context_length, batch_size, dtype_bytes)
 
 
422
  flash_savings = calculate_flash_attention_savings(kv_bytes, context_length)
423
 
 
424
  if mode == "Training (Full)":
425
  training_mem = estimate_training_memory(param_count, dtype_bytes, optimizer)
426
  base_gb = bytes_to_gb(training_mem["total_base"])
 
 
427
  activation_gb = weights_gb * 2 * batch_size
428
  if use_flash_attention and flash_savings:
429
  activation_gb -= flash_savings["savings_gb"]
430
  activation_gb = max(0.1, activation_gb)
 
431
  total_gb = base_gb + kv_gb + activation_gb
432
 
433
+ results.append(f"\n### Training Memory Breakdown")
434
+ results.append(f"- Weights: {weights_gb:.2f} GB")
435
+ results.append(f"- Gradients: {bytes_to_gb(training_mem['gradients']):.2f} GB")
436
+ results.append(f"- Optimizer ({optimizer}): {bytes_to_gb(training_mem['optimizer']):.2f} GB")
437
+ results.append(f"- KV Cache: {kv_gb:.2f} GB")
438
+ results.append(f"- Activations (est.): {activation_gb:.2f} GB")
439
 
440
  chart_data = {
441
  "Weights": weights_gb,
 
449
  lora_mem = estimate_lora_memory(param_count, dtype_bytes, lora_rank, use_qlora=False)
450
  total_gb = bytes_to_gb(lora_mem["total"])
451
 
452
+ results.append(f"\n### LoRA Fine-tuning (rank={lora_rank})")
453
+ results.append(f"- Base weights (frozen): {bytes_to_gb(lora_mem['base_weights']):.2f} GB")
454
+ results.append(f"- LoRA adapters: {bytes_to_gb(lora_mem['lora_weights']):.3f} GB ({lora_mem['lora_params']:,} params)")
455
+ results.append(f"- Gradients (LoRA only): {bytes_to_gb(lora_mem['gradients']):.3f} GB")
456
+ results.append(f"- Optimizer states: {bytes_to_gb(lora_mem['optimizer']):.3f} GB")
457
+ results.append(f"- Activations: {bytes_to_gb(lora_mem['activations']):.2f} GB")
458
  results.append(f"\n*Saves ~{(1-lora_mem['vs_full_finetune_ratio'])*100:.0f}% vs full fine-tuning*")
459
 
460
  chart_data = {
 
469
  lora_mem = estimate_lora_memory(param_count, dtype_bytes, lora_rank, use_qlora=True)
470
  total_gb = bytes_to_gb(lora_mem["total"])
471
 
472
+ results.append(f"\n### QLoRA Fine-tuning (4-bit base, rank={lora_rank})")
473
+ results.append(f"- Base weights (4-bit): {bytes_to_gb(lora_mem['base_weights']):.2f} GB")
474
+ results.append(f"- LoRA adapters: {bytes_to_gb(lora_mem['lora_weights']):.3f} GB ({lora_mem['lora_params']:,} params)")
475
+ results.append(f"- Gradients (LoRA only): {bytes_to_gb(lora_mem['gradients']):.3f} GB")
476
+ results.append(f"- Optimizer states: {bytes_to_gb(lora_mem['optimizer']):.3f} GB")
477
+ results.append(f"- Activations: {bytes_to_gb(lora_mem['activations']):.2f} GB")
478
  results.append(f"\n*Saves ~{(1-lora_mem['vs_full_finetune_ratio'])*100:.0f}% vs full fine-tuning*")
479
 
480
  chart_data = {
 
486
  }
487
 
488
  else:
 
489
  framework_overhead = SERVING_FRAMEWORKS.get(serving_framework, 1.15)
490
  base_total = weights_gb + kv_gb
491
  overhead_gb = base_total * (framework_overhead - 1)
 
 
492
  if use_flash_attention and flash_savings:
493
  overhead_gb -= min(flash_savings["savings_gb"] * 0.1, overhead_gb * 0.5)
494
  overhead_gb = max(0, overhead_gb)
 
495
  total_gb = base_total + overhead_gb
496
 
497
+ results.append(f"\n### Inference Memory ({serving_framework})")
498
+ results.append(f"- Weights: {weights_gb:.2f} GB")
499
+ results.append(f"- KV Cache: {kv_gb:.2f} GB")
500
+ results.append(f"- Framework overhead: {overhead_gb:.2f} GB ({(framework_overhead-1)*100:.0f}%)")
501
 
502
  chart_data = {
503
  "Weights": weights_gb,
 
505
  "Overhead": overhead_gb,
506
  }
507
 
 
508
  if use_flash_attention and flash_savings and flash_savings["savings_gb"] > 0.01:
509
+ results.append(f"\n### Flash Attention")
510
+ results.append(f"- Enabled: Yes")
511
+ results.append(f"- Peak memory savings: ~{flash_savings['savings_gb']:.2f} GB ({flash_savings['savings_percent']:.1f}%)")
512
 
513
+ results.append(f"\n### Total VRAM Required: **{total_gb:.2f} GB**")
514
 
 
515
  if num_gpus > 1:
516
  multi_gpu = calculate_multi_gpu_split(total_gb, num_gpus, parallelism)
517
+ results.append(f"\n### Multi-GPU ({num_gpus}x GPUs, {parallelism})")
518
+ results.append(f"- Per GPU: {multi_gpu['per_gpu']:.2f} GB")
519
+ results.append(f"- Total across GPUs: {multi_gpu['total']:.2f} GB")
520
+ results.append(f"- Efficiency: {multi_gpu['efficiency']}")
 
 
521
  effective_vram_needed = multi_gpu['per_gpu']
522
  else:
523
  effective_vram_needed = total_gb
524
 
525
+ results.append(f"\n### GPU Recommendations")
 
526
  results.append("| GPU | VRAM | Fits? | Headroom | Est. tok/s | Instance |")
527
  results.append("|-----|------|-------|----------|------------|----------|")
528
 
529
  for gpu_name, (vram, instance, category, hourly_cost, tflops) in GPU_SPECS.items():
530
+ fits = "Yes" if vram >= effective_vram_needed else "No"
531
  headroom = vram - effective_vram_needed
532
  headroom_str = f"+{headroom:.1f} GB" if headroom > 0 else f"{headroom:.1f} GB"
 
 
533
  if show_throughput and vram >= effective_vram_needed:
534
  throughput = estimate_throughput(param_count, tflops, batch_size, context_length)
535
  tok_str = f"~{throughput['estimated_tokens_per_sec']:.0f}"
536
  else:
537
  tok_str = "-"
 
538
  results.append(f"| {gpu_name} | {vram} GB | {fits} | {headroom_str} | {tok_str} | {instance} |")
539
 
 
540
  if effective_vram_needed > 24:
541
+ results.append(f"\n### Quantization Options")
542
+ results.append("To fit on consumer GPUs (24 GB or less), consider:\n")
543
  results.append("| Method | Est. Size | Quality | Notes |")
544
  results.append("|--------|-----------|---------|-------|")
 
545
  for method, specs in QUANTIZATION_METHODS.items():
546
  quant_size = bytes_to_gb(param_count * specs["bytes_per_param"])
547
+ quant_with_overhead = quant_size * 1.1
548
+ fits = "Yes" if quant_with_overhead <= 24 else "No"
549
+ results.append(f"| {method} | {quant_with_overhead:.1f} GB | {specs['quality']} | {fits} - {specs['desc']} |")
550
+ model_name = model_id.split('/')[-1]
551
+ results.append(f"\n**Tip:** Search for `{model_name} GGUF` or `{model_name} AWQ` on HuggingFace.")
552
 
 
553
  if show_cost:
554
  cost_estimates = calculate_cost_estimate(effective_vram_needed)
555
  if cost_estimates:
556
+ results.append(f"\n### Cloud Cost Estimates")
557
  results.append("*Based on 8 hrs/day, 22 days/month*\n")
558
  results.append("| GPU | Hourly | Daily | Monthly |")
559
  results.append("|-----|--------|-------|---------|")
560
+ for est in cost_estimates[:5]:
561
  results.append(f"| {est['gpu']} | ${est['hourly']:.2f} | ${est['daily']:.2f} | ${est['monthly']:.0f} |")
562
 
563
+ # Create DataFrame for chart
564
+ df = pd.DataFrame({
565
+ "Component": list(chart_data.keys()),
566
+ "GB": list(chart_data.values())
567
+ })
568
 
569
+ return "\n".join(results), df
 
 
 
570
 
 
 
571
 
572
+ def compare_models_fn(model_ids_text: str, context_length: int = 4096) -> str:
 
 
 
 
 
 
 
 
 
 
573
  """Compare multiple models side by side."""
574
  model_ids = [m.strip() for m in model_ids_text.split("\n") if m.strip()]
575
 
576
  if len(model_ids) < 2:
577
  return "Please enter at least 2 model IDs (one per line)"
 
578
  if len(model_ids) > 5:
579
  return "Maximum 5 models for comparison"
580
 
 
588
  param_count, dominant_dtype = estimate_params_from_safetensors(info)
589
 
590
  if param_count == 0:
591
+ comparison_data.append({"model": model_id, "error": "Could not determine parameters"})
 
 
 
 
592
  continue
593
 
594
  dtype_bytes = DTYPE_BYTES.get(dominant_dtype, 2)
595
  weights_gb = bytes_to_gb(param_count * dtype_bytes)
596
 
597
  num_layers = config.get("num_hidden_layers", config.get("n_layer", 0))
598
+ num_kv_heads = config.get("num_key_value_heads", config.get("num_attention_heads", 0))
 
599
  head_dim = get_head_dim(config)
600
 
601
+ kv_bytes = estimate_kv_cache_size(num_layers, num_kv_heads, head_dim, context_length, 1, dtype_bytes)
 
 
602
  kv_gb = bytes_to_gb(kv_bytes)
603
  total_inference = weights_gb + kv_gb
604
 
 
605
  training_mem = estimate_training_memory(param_count, dtype_bytes)
606
  training_gb = bytes_to_gb(training_mem["total_base"]) + weights_gb * 2
607
 
 
608
  qlora_mem = estimate_lora_memory(param_count, dtype_bytes, 16, use_qlora=True)
609
  qlora_gb = bytes_to_gb(qlora_mem["total"])
610
 
 
612
  "model": model_id.split("/")[-1],
613
  "full_id": model_id,
614
  "params": f"{param_count/1e9:.1f}B",
 
 
 
615
  "inference_gb": total_inference,
616
  "training_gb": training_gb,
617
  "qlora_gb": qlora_gb,
618
  })
619
  except Exception as e:
620
+ comparison_data.append({"model": model_id, "error": str(e)})
 
 
 
621
 
 
622
  results.append(f"*Context length: {context_length:,}*\n")
623
  results.append("| Model | Params | Inference | Training | QLoRA |")
624
  results.append("|-------|--------|-----------|----------|-------|")
 
629
  else:
630
  results.append(
631
  f"| [{data['model']}](https://huggingface.co/{data['full_id']}) | "
632
+ f"{data['params']} | {data['inference_gb']:.1f} GB | "
633
+ f"{data['training_gb']:.1f} GB | {data['qlora_gb']:.1f} GB |"
 
 
634
  )
635
 
 
636
  valid_data = [d for d in comparison_data if "error" not in d]
637
  if len(valid_data) >= 2:
638
  results.append("\n### Recommendations")
 
639
  min_inference = min(valid_data, key=lambda x: x["inference_gb"])
640
  min_training = min(valid_data, key=lambda x: x["training_gb"])
641
  min_qlora = min(valid_data, key=lambda x: x["qlora_gb"])
 
642
  results.append(f"- **Best for inference:** {min_inference['model']} ({min_inference['inference_gb']:.1f} GB)")
643
  results.append(f"- **Best for training:** {min_training['model']} ({min_training['training_gb']:.1f} GB)")
644
  results.append(f"- **Best for QLoRA:** {min_qlora['model']} ({min_qlora['qlora_gb']:.1f} GB)")
 
646
  return "\n".join(results)
647
 
648
 
649
+ def export_results_fn(result_text: str, format_type: str) -> str:
650
  """Export results to different formats."""
651
  if not result_text:
652
  return "No results to export. Run a calculation first."
 
654
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
655
 
656
  if format_type == "JSON":
 
 
657
  lines = result_text.split("\n")
658
+ data = {"timestamp": timestamp, "raw_markdown": result_text, "sections": {}}
 
 
 
 
 
659
  current_section = "header"
660
  for line in lines:
661
  if line.startswith("### "):
 
665
  if current_section not in data["sections"]:
666
  data["sections"][current_section] = []
667
  data["sections"][current_section].append(line.strip())
 
668
  return json.dumps(data, indent=2)
669
+ else:
670
+ plain = result_text.replace("**", "").replace("###", "\n===").replace("##", "\n===")
671
+ return f"VRAM Calculator Export - {timestamp}\n{'='*50}\n\n{plain}"
672
 
 
 
 
 
 
 
 
 
673
 
674
def do_search(query: str):
    """Search HuggingFace for models and populate the results dropdown.

    Returns a ``gr.update`` for the search-results Dropdown: the matching
    model IDs with the top hit preselected, a "No models found"
    placeholder when the search comes back empty, or a cleared dropdown
    for a blank query.
    """
    if not query:
        # Nothing to search for — clear any stale results.
        return gr.update(choices=[], value=None)
    matches = search_models_fn(query)
    if not matches:
        # Placeholder entry; select_from_search() filters it back out.
        return gr.update(choices=["No models found"], value=None)
    return gr.update(choices=matches, value=matches[0])
682
 
 
 
 
 
683
 
684
def select_from_search(selected: str) -> str:
    """Copy a model chosen in the search dropdown into the Model ID box.

    The "No models found" placeholder (and an empty/None selection) map
    to an empty string so they can never end up as a model ID.
    """
    is_real_choice = bool(selected) and selected != "No models found"
    return selected if is_real_choice else ""
689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
 
691
+ # Build Gradio interface
692
+ with gr.Blocks(title="VRAM Calculator", theme=gr.themes.Soft()) as demo:
693
+ gr.Markdown("# VRAM & Instance Type Calculator")
694
+ gr.Markdown("Estimate GPU memory requirements for HuggingFace models.")
695
+
696
+ with gr.Tab("Calculator"):
697
+ with gr.Row():
698
+ model_input = gr.Textbox(
699
+ label="Model ID",
700
+ placeholder="meta-llama/Llama-3.1-8B",
701
+ info="Full HuggingFace model ID (org/model-name)",
702
+ scale=2
 
 
 
 
 
703
  )
704
+ search_input = gr.Textbox(
705
+ label="Search Models",
706
+ placeholder="llama 8b",
707
+ info="Search HuggingFace",
708
+ scale=1
709
  )
710
 
711
+ with gr.Row():
712
+ search_btn = gr.Button("Search HuggingFace", scale=1)
713
+ search_results = gr.Dropdown(
714
+ label="Search Results",
715
+ choices=[],
716
+ interactive=True,
717
+ scale=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
  )
719
 
720
+ search_btn.click(fn=do_search, inputs=[search_input], outputs=[search_results])
721
+ search_results.change(fn=select_from_search, inputs=[search_results], outputs=[model_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
+ with gr.Row():
724
+ mode_input = gr.Radio(
725
+ choices=["Inference", "Training (Full)", "LoRA Fine-tuning", "QLoRA Fine-tuning"],
726
+ value="Inference",
727
+ label="Mode"
 
 
 
 
 
 
728
  )
729
+ context_input = gr.Slider(
730
  label="Context Length",
731
  minimum=512,
732
  maximum=131072,
733
  value=4096,
734
+ step=512
735
  )
736
+ batch_input = gr.Slider(
737
+ label="Batch Size",
738
+ minimum=1,
739
+ maximum=64,
740
+ value=1,
741
+ step=1
 
742
  )
743
 
744
+ with gr.Accordion("Advanced Options", open=False):
745
+ with gr.Row():
746
+ serving_input = gr.Dropdown(
747
+ choices=list(SERVING_FRAMEWORKS.keys()),
748
+ value="None (raw PyTorch)",
749
+ label="Serving Framework"
750
+ )
751
+ optimizer_input = gr.Dropdown(
752
+ choices=["AdamW", "SGD", "SGD + Momentum", "8-bit Adam"],
753
+ value="AdamW",
754
+ label="Optimizer (Training mode)"
755
+ )
756
+ lora_rank_input = gr.Slider(
757
+ label="LoRA Rank",
758
+ minimum=4,
759
+ maximum=128,
760
+ value=16,
761
+ step=4
762
+ )
763
 
764
+ with gr.Row():
765
+ num_gpus_input = gr.Slider(
766
+ label="Number of GPUs",
767
+ minimum=1,
768
+ maximum=8,
769
+ value=1,
770
+ step=1
771
+ )
772
+ parallelism_input = gr.Dropdown(
773
+ choices=["Tensor Parallelism", "Pipeline Parallelism", "Data Parallelism"],
774
+ value="Tensor Parallelism",
775
+ label="Parallelism Strategy"
776
+ )
777
+ flash_attention_input = gr.Checkbox(
778
+ label="Use Flash Attention",
779
+ value=True
780
+ )
 
 
 
 
 
 
781
 
782
+ with gr.Row():
783
+ show_throughput_input = gr.Checkbox(label="Show Throughput Estimates", value=True)
784
+ show_cost_input = gr.Checkbox(label="Show Cost Estimates", value=True)
785
+
786
+ calculate_btn = gr.Button("Calculate VRAM", variant="primary", size="lg")
787
+
788
+ with gr.Row():
789
+ output = gr.Markdown(label="Results")
790
+ chart_output = gr.BarPlot(
791
+ x="Component",
792
+ y="GB",
793
+ title="Memory Breakdown",
794
+ height=350,
795
+ width=400
796
  )
797
 
798
+ calculate_btn.click(
799
+ fn=calculate_vram,
800
+ inputs=[
801
+ model_input, context_input, batch_input, mode_input,
802
+ optimizer_input, serving_input, num_gpus_input, parallelism_input,
803
+ flash_attention_input, lora_rank_input, show_throughput_input, show_cost_input
804
+ ],
805
+ outputs=[output, chart_output]
806
+ )
807
+
808
+ gr.Markdown("### Popular Models")
809
+ gr.Examples(
810
+ examples=[
811
+ ["meta-llama/Llama-3.1-8B", 4096, 1],
812
+ ["meta-llama/Llama-3.1-70B", 8192, 1],
813
+ ["mistralai/Mistral-7B-v0.1", 8192, 1],
814
+ ["Qwen/Qwen2.5-72B", 32768, 1],
815
+ ["google/gemma-2-27b", 8192, 1],
816
+ ["microsoft/phi-4", 16384, 1],
817
+ ],
818
+ inputs=[model_input, context_input, batch_input],
819
+ )
820
+
821
+ with gr.Tab("Compare Models"):
822
+ gr.Markdown("Compare VRAM requirements across multiple models. Enter model IDs one per line (2-5 models).")
823
+
824
+ compare_models_input = gr.Textbox(
825
+ label="Model IDs (one per line)",
826
+ placeholder="meta-llama/Llama-3.1-8B\nmistralai/Mistral-7B-v0.1\nQwen/Qwen2.5-7B",
827
+ lines=5,
828
+ )
829
+ compare_context_input = gr.Slider(
830
+ label="Context Length",
831
+ minimum=512,
832
+ maximum=131072,
833
+ value=4096,
834
+ step=512,
835
+ )
836
+ compare_btn = gr.Button("Compare Models", variant="primary")
837
+ compare_output = gr.Markdown(label="Comparison Results")
838
+
839
+ compare_btn.click(
840
+ fn=compare_models_fn,
841
+ inputs=[compare_models_input, compare_context_input],
842
+ outputs=[compare_output]
843
+ )
844
+
845
+ gr.Markdown("### Example Comparisons")
846
+ gr.Examples(
847
+ examples=[
848
+ ["meta-llama/Llama-3.1-8B\nmistralai/Mistral-7B-v0.1\nQwen/Qwen2.5-7B", 4096],
849
+ ["meta-llama/Llama-3.1-70B\nQwen/Qwen2.5-72B", 8192],
850
+ ],
851
+ inputs=[compare_models_input, compare_context_input],
852
+ )
853
+
854
+ with gr.Tab("Export"):
855
+ gr.Markdown("Export calculation results to JSON or plain text. Copy results from Calculator tab.")
856
+
857
+ export_input = gr.Textbox(
858
+ label="Paste Results Here",
859
+ placeholder="Paste the calculation results...",
860
+ lines=10,
861
+ )
862
+ export_format = gr.Radio(
863
+ choices=["JSON", "Plain Text"],
864
+ value="JSON",
865
+ label="Export Format"
866
+ )
867
+ export_btn = gr.Button("Export", variant="primary")
868
+ export_output = gr.Textbox(
869
+ label="Exported Data",
870
+ lines=15,
871
+ show_copy_button=True,
872
+ )
873
+
874
+ export_btn.click(
875
+ fn=export_results_fn,
876
+ inputs=[export_input, export_format],
877
+ outputs=[export_output]
878
+ )
879
+
880
  gr.Markdown("""
881
  ---
882
+ **Notes:** Estimates are approximate. Flash Attention and other optimizations can reduce peak memory.
883
+ Throughput estimates assume ideal conditions. Built with Gradio & HuggingFace Hub API.
 
 
 
 
 
 
 
 
 
 
 
884
  """)
885
 
886