Livengood Claude committed

Commit 50bc6be · 1 Parent(s): 6f1d86c

Add training mode, multi-GPU support, expanded GPU database, quantization breakdown, and visual memory chart

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2):
  1. app.py +436 -134
  2. requirements.txt +1 -0
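Before the diff, a quick check of the formula it is built around. The KV-cache rule in `estimate_kv_cache_size` below is easy to verify by hand; this is a minimal sketch, with input values assumed from Llama-3.1-8B's published config (32 layers, 8 KV heads under GQA, head_dim 128) at FP16:

```python
# Hedged sanity check of the KV-cache formula used in app.py below.
# Assumed inputs from Llama-3.1-8B's config.json: 32 layers,
# 8 key/value heads (GQA), head_dim 128, FP16 cache (2 bytes/element).
num_layers, num_kv_heads, head_dim = 32, 8, 128
batch_size, context_length, dtype_bytes = 1, 4096, 2

# The leading 2 accounts for storing both the K and the V cache.
kv_bytes = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
print(kv_bytes / 1024**3)  # 0.5 GiB at 4K context
```

The figure doubles with every doubling of context length or batch size, which is why the app tabulates it across context lengths.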
app.py CHANGED

@@ -2,28 +2,45 @@
 VRAM & Instance Type Calculator for HuggingFace Models
 
 Fetches model metadata from HF Hub and calculates:
-- Minimum VRAM required for inference
+- Minimum VRAM required for inference and training
 - KV cache requirements at various context lengths
 - Recommended GPUs and cloud instances
+- Multi-GPU tensor parallelism estimates
+- Quantization options with detailed breakdown
 """
 
 import gradio as gr
 from huggingface_hub import HfApi, hf_hub_download
 import json
-import math
+from functools import lru_cache
 
 # Initialize HF API client
 api = HfApi()
 
-# GPU specs: name -> (VRAM in GB, typical cloud instance)
+# GPU specs: name -> (VRAM in GB, typical cloud instance, category)
 GPU_SPECS = {
-    "RTX 3090": (24, "Consumer"),
-    "RTX 4090": (24, "Consumer"),
-    "A10G": (24, "AWS g5.xlarge (~$1/hr)"),
-    "L4": (24, "GCP g2-standard-4 (~$0.70/hr)"),
-    "A100 40GB": (40, "AWS p4d.24xlarge, GCP a2-highgpu-1g"),
-    "A100 80GB": (80, "AWS p4de.24xlarge, GCP a2-ultragpu-1g"),
-    "H100 80GB": (80, "AWS p5.48xlarge, GCP a3-highgpu-8g"),
+    # Consumer GPUs
+    "RTX 3080": (10, "Consumer", "consumer"),
+    "RTX 3090": (24, "Consumer", "consumer"),
+    "RTX 4080": (16, "Consumer", "consumer"),
+    "RTX 4090": (24, "Consumer", "consumer"),
+    "RTX 5090": (32, "Consumer (est.)", "consumer"),
+    # Apple Silicon
+    "M2 Ultra": (192, "Mac Studio (Unified)", "apple"),
+    "M3 Max": (128, "MacBook Pro (Unified)", "apple"),
+    "M4 Max": (128, "MacBook Pro (Unified)", "apple"),
+    # Workstation GPUs
+    "RTX A6000": (48, "Workstation", "workstation"),
+    "L40S": (48, "AWS g6.xlarge (~$1.00/hr)", "cloud"),
+    # Cloud GPUs
+    "A10G": (24, "AWS g5.xlarge (~$1.00/hr)", "cloud"),
+    "L4": (24, "GCP g2-standard-4 (~$0.70/hr)", "cloud"),
+    "A100 40GB": (40, "AWS p4d, GCP a2-highgpu-1g (~$3/hr)", "cloud"),
+    "A100 80GB": (80, "AWS p4de, GCP a2-ultragpu-1g (~$5/hr)", "cloud"),
+    "H100 80GB": (80, "AWS p5, GCP a3-highgpu (~$8/hr)", "cloud"),
+    "H200 141GB": (141, "Coming soon (~$12/hr est.)", "cloud"),
+    # AMD GPUs
+    "MI300X": (192, "AMD Cloud Instances", "amd"),
 }
 
 # Bytes per element for different dtypes

@@ -38,36 +55,79 @@ DTYPE_BYTES = {
     "I64": 8, "int64": 8,
 }
 
+# Serving framework overhead multipliers
+SERVING_FRAMEWORKS = {
+    "None (raw PyTorch)": 1.20,
+    "vLLM": 1.10,
+    "TGI (Text Generation Inference)": 1.15,
+    "llama.cpp": 1.05,
+    "Transformers (HuggingFace)": 1.25,
+    "Ollama": 1.08,
+}
+
+# Quantization methods with their characteristics
+QUANTIZATION_METHODS = {
+    "FP16/BF16": {"bytes_per_param": 2.0, "quality": "100%", "desc": "Full precision"},
+    "INT8 (LLM.int8)": {"bytes_per_param": 1.0, "quality": "~99%", "desc": "Good balance"},
+    "GPTQ 8-bit": {"bytes_per_param": 1.0, "quality": "~99%", "desc": "GPU optimized"},
+    "AWQ 4-bit": {"bytes_per_param": 0.5, "quality": "~97%", "desc": "Activation-aware"},
+    "GPTQ 4-bit": {"bytes_per_param": 0.5, "quality": "~95%", "desc": "GPU optimized"},
+    "GGUF Q8_0": {"bytes_per_param": 1.0, "quality": "~99%", "desc": "llama.cpp format"},
+    "GGUF Q6_K": {"bytes_per_param": 0.75, "quality": "~98%", "desc": "llama.cpp format"},
+    "GGUF Q5_K_M": {"bytes_per_param": 0.625, "quality": "~97%", "desc": "llama.cpp format"},
+    "GGUF Q4_K_M": {"bytes_per_param": 0.5, "quality": "~95%", "desc": "llama.cpp format"},
+    "GGUF Q3_K_M": {"bytes_per_param": 0.375, "quality": "~90%", "desc": "llama.cpp format"},
+    "GGUF Q2_K": {"bytes_per_param": 0.3125, "quality": "~85%", "desc": "Aggressive compression"},
+}
 
 
-def bytes_to_gb(b: int) -> float:
+def bytes_to_gb(b: int | float) -> float:
     return b / (1024 ** 3)
 
 
-def get_model_info(model_id: str) -> dict:
-    """Fetch model info from HF Hub."""
+def gb_to_bytes(gb: float) -> float:
+    return gb * (1024 ** 3)
+
+
+@lru_cache(maxsize=50)
+def get_model_info_cached(model_id: str):
+    """Fetch model info from HF Hub with caching."""
     try:
         info = api.model_info(model_id, files_metadata=True)
         return info
     except Exception as e:
-        raise gr.Error(f"Could not fetch model info: {e}")
+        return {"_error": str(e)}
 
 
-def get_config(model_id: str) -> dict:
-    """Try to fetch config.json for architecture details."""
+@lru_cache(maxsize=50)
+def get_config_cached(model_id: str) -> str:
+    """Fetch config.json with caching. Returns JSON string for cache compatibility."""
     try:
         config_path = hf_hub_download(model_id, "config.json")
         with open(config_path) as f:
-            return json.load(f)
+            return f.read()
     except Exception as e:
-        # Gated models or missing config
-        return {"_error": str(e)}
+        return json.dumps({"_error": str(e)})
+
+
+def get_model_info(model_id: str):
+    """Fetch model info from HF Hub."""
+    result = get_model_info_cached(model_id)
+    if isinstance(result, dict) and "_error" in result:
+        raise gr.Error(f"Could not fetch model info: {result['_error']}")
+    return result
+
+
+def get_config(model_id: str) -> dict:
+    """Get config.json for architecture details."""
+    config_str = get_config_cached(model_id)
+    return json.loads(config_str)
 
 
 def estimate_params_from_safetensors(info) -> tuple[int, str]:
     """Extract parameter count and dtype from safetensors metadata."""
-    if info.safetensors:
+    if hasattr(info, 'safetensors') and info.safetensors:
         param_count = info.safetensors.total
-        # Get the dominant dtype
         params_by_dtype = info.safetensors.parameters
         if params_by_dtype:
             dominant_dtype = max(params_by_dtype, key=params_by_dtype.get)

@@ -75,152 +135,329 @@ def estimate_params_from_safetensors(info) -> tuple[int, str]:
     return 0, "F16"
 
 
+def get_head_dim(config: dict) -> int:
+    """Calculate head dimension from config, with fallbacks."""
+    # Try to get it directly
+    if "head_dim" in config:
+        return config["head_dim"]
+
+    # Calculate from hidden_size and num_attention_heads
+    hidden_size = config.get("hidden_size", config.get("n_embd", 0))
+    num_heads = config.get("num_attention_heads", config.get("n_head", 0))
+
+    if hidden_size and num_heads:
+        return hidden_size // num_heads
+
+    # Common defaults by model family
+    return 128  # Most common default
+
+
 def estimate_kv_cache_size(
     num_layers: int,
-    hidden_size: int,
     num_kv_heads: int,
+    head_dim: int,
     context_length: int,
     batch_size: int = 1,
     dtype_bytes: int = 2
 ) -> int:
     """
     KV cache size = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
-    head_dim = hidden_size / num_attention_heads (but we use hidden_size / num_kv_heads for GQA)
+
+    The 2 accounts for both K and V caches.
     """
-    # For GQA models, KV cache uses num_kv_heads, not num_attention_heads
-    # head_dim is typically hidden_size / num_attention_heads
-    # But KV cache stores: num_kv_heads * head_dim per layer
-    # Simplified: 2 * layers * batch * seq * hidden_size * (num_kv_heads / num_attn_heads) * dtype
-    # For non-GQA: num_kv_heads == num_attn_heads, so it's just 2 * layers * batch * seq * hidden
-
-    # More accurate: 2 (K+V) * layers * batch * seq * num_kv_heads * head_dim
-    # We'll estimate head_dim as hidden_size / num_kv_heads if we don't know num_attn_heads
-    # This is a rough estimate
-
-    head_dim = 128  # Common default (Llama, Mistral, etc.)
     kv_cache_bytes = 2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * dtype_bytes
     return kv_cache_bytes
 
 
-def calculate_vram(model_id: str, context_length: int = 4096, batch_size: int = 1) -> str:
-    """Main calculation function."""
-
+def estimate_training_memory(
+    param_count: int,
+    dtype_bytes: int,
+    optimizer: str = "AdamW"
+) -> dict:
+    """
+    Estimate training memory requirements.
+
+    For training, we need:
+    - Model weights
+    - Gradients (same size as weights)
+    - Optimizer states (varies by optimizer)
+    - Activations (highly variable, estimated)
+    """
+    weights_bytes = param_count * dtype_bytes
+    gradients_bytes = param_count * dtype_bytes
+
+    # Optimizer states
+    if optimizer == "AdamW":
+        # AdamW stores: m (momentum), v (variance) in FP32
+        optimizer_bytes = param_count * 4 * 2  # 2 states, 4 bytes each
+    elif optimizer == "SGD":
+        optimizer_bytes = 0  # No extra state (momentum optional)
+    elif optimizer == "SGD + Momentum":
+        optimizer_bytes = param_count * 4  # Momentum buffer
+    elif optimizer == "8-bit Adam":
+        optimizer_bytes = param_count * 1 * 2  # 2 states, 1 byte each
+    else:
+        optimizer_bytes = param_count * 4 * 2  # Default to AdamW
+
+    return {
+        "weights": weights_bytes,
+        "gradients": gradients_bytes,
+        "optimizer": optimizer_bytes,
+        "total_base": weights_bytes + gradients_bytes + optimizer_bytes
+    }
+
+
+def calculate_multi_gpu_split(total_vram_gb: float, num_gpus: int, parallelism: str) -> dict:
+    """Calculate memory distribution across multiple GPUs."""
+    if parallelism == "Tensor Parallelism":
+        # Weights and KV cache split evenly
+        per_gpu = total_vram_gb / num_gpus
+        overhead = 0.05 * total_vram_gb  # Communication overhead
+        return {
+            "per_gpu": per_gpu + (overhead / num_gpus),
+            "total": total_vram_gb + overhead,
+            "efficiency": "High (best for inference)",
+        }
+    elif parallelism == "Pipeline Parallelism":
+        # Layers distributed, but activation memory at boundaries
+        per_gpu = total_vram_gb / num_gpus
+        overhead = 0.1 * total_vram_gb  # Activation memory overhead
+        return {
+            "per_gpu": per_gpu + (overhead / num_gpus),
+            "total": total_vram_gb + overhead,
+            "efficiency": "Medium (good for training)",
+        }
+    else:  # Data Parallelism
+        # Full model on each GPU
+        return {
+            "per_gpu": total_vram_gb,
+            "total": total_vram_gb * num_gpus,
+            "efficiency": "Low memory efficiency (training only)",
+        }
+
+
+def calculate_vram(
+    model_id: str,
+    context_length: int = 4096,
+    batch_size: int = 1,
+    mode: str = "Inference",
+    optimizer: str = "AdamW",
+    serving_framework: str = "None (raw PyTorch)",
+    num_gpus: int = 1,
+    parallelism: str = "Tensor Parallelism"
+) -> tuple[str, dict | None]:
+    """Main calculation function. Returns (markdown_results, chart_data)."""
+
+    # Validate inputs
+    model_id = model_id.strip()
+    if not model_id:
+        raise gr.Error("Please enter a model ID")
+
+    if "/" not in model_id:
+        raise gr.Error("Model ID should be in format 'organization/model-name'")
+
     # Fetch model info
     info = get_model_info(model_id)
     config = get_config(model_id)
 
     results = []
     results.append(f"## Model: [{model_id}](https://huggingface.co/{model_id})\n")
 
     # Get parameter count and dtype
     param_count, dominant_dtype = estimate_params_from_safetensors(info)
 
     if param_count == 0:
-        # Fallback: try to infer from model name or config
         results.append("⚠️ Could not determine parameter count from safetensors metadata.\n")
         results.append("Model may use pytorch_model.bin or other format.\n")
-        return "\n".join(results)
+        return "\n".join(results), None
 
     dtype_bytes = DTYPE_BYTES.get(dominant_dtype, 2)
     params_b = param_count / 1e9
 
-    results.append(f"**Parameters:** {params_b:.2f}B")
-    results.append(f"**Dominant dtype:** {dominant_dtype} ({dtype_bytes} bytes)")
+    results.append(f"**Parameters:** {params_b:.2f}B ({param_count:,})")
+    results.append(f"**Dominant dtype:** {dominant_dtype} ({dtype_bytes} bytes/param)")
+    results.append(f"**Mode:** {mode}")
 
     # Model weights VRAM
     weights_bytes = param_count * dtype_bytes
     weights_gb = bytes_to_gb(weights_bytes)
-    results.append(f"\n### Weight Memory")
+    results.append(f"\n### 📦 Weight Memory")
     results.append(f"Model weights: **{weights_gb:.2f} GB**")
 
-    # KV Cache estimation (if we have config)
+    # Architecture details
     num_layers = config.get("num_hidden_layers", config.get("n_layer", 0))
     hidden_size = config.get("hidden_size", config.get("n_embd", 0))
-    num_kv_heads = config.get("num_key_value_heads", config.get("num_attention_heads", config.get("n_head", 0)))
+    num_attention_heads = config.get("num_attention_heads", config.get("n_head", 0))
+    num_kv_heads = config.get("num_key_value_heads", num_attention_heads)
+    head_dim = get_head_dim(config)
+    max_position = config.get("max_position_embeddings", config.get("n_positions", "N/A"))
 
-    results.append(f"\n### Architecture (from config.json)")
+    results.append(f"\n### 🏗️ Architecture (from config.json)")
     if "_error" in config:
-        results.append(f"⚠️ Could not fetch config.json (model may be gated or config missing)")
-        results.append("KV cache calculation skipped - using weight-only estimate with 20% overhead")
+        results.append(f"⚠️ Could not fetch config.json (model may be gated)")
+        kv_gb = 0
     elif num_layers and hidden_size:
-        results.append(f"Layers: {num_layers}, Hidden size: {hidden_size}, KV Heads: {num_kv_heads}")
-
-        # Calculate KV cache for different context lengths
-        results.append(f"\n### KV Cache (batch_size={batch_size})")
-        results.append("| Context Length | KV Cache | Total VRAM |")
-        results.append("|----------------|----------|------------|")
-
-        for ctx_len in [2048, 4096, 8192, 16384, 32768, 65536, 131072]:
-            if ctx_len > context_length * 2:
+        results.append(f"- **Layers:** {num_layers}")
+        results.append(f"- **Hidden size:** {hidden_size}")
+        results.append(f"- **Attention heads:** {num_attention_heads}")
+        results.append(f"- **KV heads:** {num_kv_heads} {'(GQA)' if num_kv_heads != num_attention_heads else '(MHA)'}")
+        results.append(f"- **Head dimension:** {head_dim}")
+        results.append(f"- **Max context:** {max_position:,}" if isinstance(max_position, int) else f"- **Max context:** {max_position}")
+
+        # KV Cache calculation
+        results.append(f"\n### 💾 KV Cache (batch_size={batch_size})")
+        results.append("| Context | KV Cache | + Weights | Status |")
+        results.append("|---------|----------|-----------|--------|")
+
+        # Show relevant context lengths
+        context_points = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
+        for ctx_len in context_points:
+            if ctx_len > context_length * 2 and ctx_len > 8192:
                 break
             kv_bytes = estimate_kv_cache_size(
-                num_layers, hidden_size, num_kv_heads, ctx_len, batch_size, dtype_bytes
+                num_layers, num_kv_heads, head_dim, ctx_len, batch_size, dtype_bytes
             )
-            kv_gb = bytes_to_gb(kv_bytes)
-            total_gb = weights_gb + kv_gb
-            marker = " selected" if ctx_len == context_length else ""
-            results.append(f"| {ctx_len:,} | {kv_gb:.2f} GB | **{total_gb:.2f} GB**{marker} |")
-    else:
-        results.append("Could not find architecture details in config.json")
-
-    # Calculate for user's selected context length
-    if num_layers and hidden_size and num_kv_heads:
+            kv_gb_temp = bytes_to_gb(kv_bytes)
+            total_temp = weights_gb + kv_gb_temp
+            marker = " **← selected**" if ctx_len == context_length else ""
+            results.append(f"| {ctx_len:,} | {kv_gb_temp:.2f} GB | {total_temp:.2f} GB |{marker} |")
+
+        # Calculate for selected context
         kv_bytes = estimate_kv_cache_size(
-            num_layers, hidden_size, num_kv_heads, context_length, batch_size, dtype_bytes
+            num_layers, num_kv_heads, head_dim, context_length, batch_size, dtype_bytes
         )
         kv_gb = bytes_to_gb(kv_bytes)
-        total_inference_gb = weights_gb + kv_gb
     else:
-        total_inference_gb = weights_gb * 1.2  # 20% overhead estimate
-
-    # Add activation memory overhead (~10-20%)
-    total_with_overhead = total_inference_gb * 1.15
-
-    results.append(f"\n### Total VRAM Estimate")
-    results.append(f"Weights + KV Cache + Overhead (~15%): **{total_with_overhead:.2f} GB**")
+        results.append("Could not find architecture details")
+        kv_gb = 0
+
+    # Calculate total based on mode
+    if mode == "Training":
+        training_mem = estimate_training_memory(param_count, dtype_bytes, optimizer)
+        base_gb = bytes_to_gb(training_mem["total_base"])
+
+        # Activations estimation (rough: ~2x weights for typical batch)
+        activation_gb = weights_gb * 2 * batch_size
+        total_gb = base_gb + kv_gb + activation_gb
+
+        results.append(f"\n### 🎓 Training Memory Breakdown")
+        results.append(f"- **Weights:** {weights_gb:.2f} GB")
+        results.append(f"- **Gradients:** {bytes_to_gb(training_mem['gradients']):.2f} GB")
+        results.append(f"- **Optimizer ({optimizer}):** {bytes_to_gb(training_mem['optimizer']):.2f} GB")
+        results.append(f"- **KV Cache:** {kv_gb:.2f} GB")
+        results.append(f"- **Activations (est.):** {activation_gb:.2f} GB")
+
+        chart_data = {
+            "Weights": weights_gb,
+            "Gradients": bytes_to_gb(training_mem['gradients']),
+            "Optimizer": bytes_to_gb(training_mem['optimizer']),
+            "KV Cache": kv_gb,
+            "Activations": activation_gb,
+        }
+    else:
+        # Inference mode
+        framework_overhead = SERVING_FRAMEWORKS.get(serving_framework, 1.15)
+        base_total = weights_gb + kv_gb
+        overhead_gb = base_total * (framework_overhead - 1)
+        total_gb = base_total + overhead_gb
+
+        results.append(f"\n### ⚡ Inference Memory ({serving_framework})")
+        results.append(f"- **Weights:** {weights_gb:.2f} GB")
+        results.append(f"- **KV Cache:** {kv_gb:.2f} GB")
+        results.append(f"- **Framework overhead:** {overhead_gb:.2f} GB ({(framework_overhead-1)*100:.0f}%)")
+
+        chart_data = {
+            "Weights": weights_gb,
+            "KV Cache": kv_gb,
+            "Overhead": overhead_gb,
+        }
+
+    results.append(f"\n### 📊 Total VRAM Required: **{total_gb:.2f} GB**")
+
+    # Multi-GPU calculations
+    if num_gpus > 1:
+        multi_gpu = calculate_multi_gpu_split(total_gb, num_gpus, parallelism)
+        results.append(f"\n### 🔗 Multi-GPU ({num_gpus}x GPUs, {parallelism})")
+        results.append(f"- **Per GPU:** {multi_gpu['per_gpu']:.2f} GB")
+        results.append(f"- **Total across GPUs:** {multi_gpu['total']:.2f} GB")
+        results.append(f"- **Efficiency:** {multi_gpu['efficiency']}")
+
+        # Update total for GPU recommendations
+        effective_vram_needed = multi_gpu['per_gpu']
+    else:
+        effective_vram_needed = total_gb
 
     # GPU Recommendations
-    results.append(f"\n### Recommended GPUs")
-    results.append("| GPU | VRAM | Fits? | Cloud Instance |")
-    results.append("|-----|------|-------|----------------|")
-
-    for gpu_name, (vram, instance) in GPU_SPECS.items():
-        fits = "✅" if vram >= total_with_overhead else "❌"
-        headroom = vram - total_with_overhead
-        headroom_str = f"+{headroom:.1f}GB" if headroom > 0 else f"{headroom:.1f}GB"
-        results.append(f"| {gpu_name} | {vram}GB | {fits} ({headroom_str}) | {instance} |")
-
-    # Quantization suggestions
-    if total_with_overhead > 24:
-        results.append(f"\n### 💡 Quantization Options")
-        results.append("To fit on consumer GPUs (24GB), consider:")
-
-        q8_estimate = (param_count * 1) / (1024**3) * 1.15
-        q4_estimate = (param_count * 0.5) / (1024**3) * 1.15
-
-        results.append(f"- **INT8 quantization:** ~{q8_estimate:.1f} GB")
-        results.append(f"- **INT4 quantization:** ~{q4_estimate:.1f} GB")
-        results.append(f"\nLook for GGUF or AWQ versions of this model on HF Hub.")
-
-    return "\n".join(results)
+    results.append(f"\n### 🎮 GPU Recommendations")
+    results.append("| GPU | VRAM | Fits? | Headroom | Instance |")
+    results.append("|-----|------|-------|----------|----------|")
+
+    for gpu_name, (vram, instance, category) in GPU_SPECS.items():
+        fits = "✅" if vram >= effective_vram_needed else "❌"
+        headroom = vram - effective_vram_needed
+        headroom_str = f"+{headroom:.1f} GB" if headroom > 0 else f"{headroom:.1f} GB"
+        results.append(f"| {gpu_name} | {vram} GB | {fits} | {headroom_str} | {instance} |")
+
+    # Quantization options (if model doesn't fit on consumer GPUs)
+    if effective_vram_needed > 24:
+        results.append(f"\n### 🗜️ Quantization Options")
+        results.append("To fit on consumer GPUs (≤24 GB), consider these options:\n")
+        results.append("| Method | Est. Size | Quality | Notes |")
+        results.append("|--------|-----------|---------|-------|")
+
+        for method, specs in QUANTIZATION_METHODS.items():
+            quant_size = bytes_to_gb(param_count * specs["bytes_per_param"])
+            quant_with_overhead = quant_size * 1.1  # Small overhead
+            fits = "✅" if quant_with_overhead <= 24 else "❌"
+            results.append(f"| {method} | {quant_with_overhead:.1f} GB | {specs['quality']} | {fits} {specs['desc']} |")
+
+        results.append(f"\n**Tip:** Search for `{model_id.split('/')[-1]} GGUF` or `{model_id.split('/')[-1]} AWQ` on HuggingFace.")
+
+    return "\n".join(results), chart_data
+
+
+def create_memory_chart(chart_data: dict | None):
+    """Create a bar chart for memory breakdown."""
+    if not chart_data:
+        return None
+
+    labels = list(chart_data.keys())
+    values = list(chart_data.values())
+
+    return gr.BarPlot(
+        value={"Component": labels, "GB": values},
+        x="Component",
+        y="GB",
+        title="Memory Breakdown",
+        height=300,
+        width=400,
+    )
 
 
 # Build Gradio interface
-with gr.Blocks(title="VRAM Calculator") as demo:
+with gr.Blocks(title="VRAM Calculator", theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
     # 🧮 VRAM & Instance Type Calculator
-
-    Enter a HuggingFace model ID to estimate VRAM requirements and get GPU/cloud instance recommendations.
-
-    **How it works:** Fetches model metadata (safetensors info, config.json) to calculate memory for weights + KV cache.
+
+    Estimate GPU memory requirements for HuggingFace models. Supports inference and training modes,
+    multi-GPU setups, and provides detailed quantization recommendations.
     """)
 
     with gr.Row():
         with gr.Column(scale=2):
             model_input = gr.Textbox(
                 label="Model ID",
                 placeholder="meta-llama/Llama-3.1-8B",
-                info="Enter the full HuggingFace model ID (e.g., 'mistralai/Mistral-7B-v0.1')"
+                info="Full HuggingFace model ID (org/model-name)"
+            )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            mode_input = gr.Radio(
+                choices=["Inference", "Training"],
+                value="Inference",
+                label="Mode",
+                info="Training requires ~4x more memory"
             )
         with gr.Column(scale=1):
             context_input = gr.Slider(

@@ -229,50 +466,115 @@ with gr.Blocks(title="VRAM Calculator") as demo:
                 maximum=131072,
                 value=4096,
                 step=512,
-                info="Max sequence length for KV cache calculation"
+                info="Sequence length for KV cache"
             )
         with gr.Column(scale=1):
             batch_input = gr.Slider(
                 label="Batch Size",
                 minimum=1,
-                maximum=32,
+                maximum=64,
                 value=1,
                 step=1,
                 info="Concurrent sequences"
             )
 
-    calculate_btn = gr.Button("Calculate VRAM", variant="primary", size="lg")
-
-    output = gr.Markdown(label="Results")
-
+    with gr.Accordion("⚙️ Advanced Options", open=False):
+        with gr.Row():
+            with gr.Column():
+                serving_input = gr.Dropdown(
+                    choices=list(SERVING_FRAMEWORKS.keys()),
+                    value="None (raw PyTorch)",
+                    label="Serving Framework",
+                    info="Different frameworks have different overhead"
+                )
+                optimizer_input = gr.Dropdown(
+                    choices=["AdamW", "SGD", "SGD + Momentum", "8-bit Adam"],
+                    value="AdamW",
+                    label="Optimizer (Training mode)",
+                    info="Optimizer state memory varies"
+                )
+            with gr.Column():
+                num_gpus_input = gr.Slider(
+                    label="Number of GPUs",
+                    minimum=1,
+                    maximum=8,
+                    value=1,
+                    step=1,
+                    info="For multi-GPU setups"
+                )
+                parallelism_input = gr.Dropdown(
+                    choices=["Tensor Parallelism", "Pipeline Parallelism", "Data Parallelism"],
+                    value="Tensor Parallelism",
+                    label="Parallelism Strategy",
+                    info="How to distribute across GPUs"
+                )
+
+    calculate_btn = gr.Button("🚀 Calculate VRAM", variant="primary", size="lg")
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            output = gr.Markdown(label="Results")
+        with gr.Column(scale=1):
+            chart_output = gr.BarPlot(
+                x="Component",
+                y="GB",
+                title="Memory Breakdown",
+                height=350,
+            )
+
+    def run_calculation(model_id, context_length, batch_size, mode, optimizer, serving, num_gpus, parallelism):
+        result_text, chart_data = calculate_vram(
+            model_id, context_length, batch_size, mode, optimizer, serving, num_gpus, parallelism
+        )
+        if chart_data:
+            import pandas as pd
+            df = pd.DataFrame({
+                "Component": list(chart_data.keys()),
+                "GB": list(chart_data.values())
+            })
+            return result_text, df
+        return result_text, None
+
     calculate_btn.click(
-        fn=calculate_vram,
-        inputs=[model_input, context_input, batch_input],
-        outputs=output
+        fn=run_calculation,
+        inputs=[
+            model_input, context_input, batch_input, mode_input,
+            optimizer_input, serving_input, num_gpus_input, parallelism_input
+        ],
+        outputs=[output, chart_output]
     )
 
     # Examples
     gr.Examples(
        examples=[
            ["meta-llama/Llama-3.1-8B", 4096, 1],
+            ["meta-llama/Llama-3.1-70B", 8192, 1],
            ["mistralai/Mistral-7B-v0.1", 8192, 1],
            ["Qwen/Qwen2.5-72B", 32768, 1],
            ["google/gemma-2-27b", 8192, 1],
            ["microsoft/phi-4", 16384, 1],
+            ["deepseek-ai/DeepSeek-V3", 4096, 1],
+            ["meta-llama/Llama-3.3-70B-Instruct", 8192, 1],
        ],
        inputs=[model_input, context_input, batch_input],
-        label="Try these models"
+        label="🔥 Popular Models"
    )
 
    gr.Markdown("""
    ---
-    **Notes:**
-    - VRAM estimates include ~15% overhead for activations and framework overhead
-    - KV cache assumes inference (not training)
-    - Actual requirements may vary based on serving framework (vLLM, TGI, etc.)
-    - For GGUF models, memory requirements differ significantly
-
-    Built with ❤️ using Gradio & HuggingFace Hub API
+    ### 📝 Notes
+    - **Inference mode:** Weights + KV cache + framework overhead
+    - **Training mode:** Adds gradients, optimizer states, and activation memory
+    - **KV cache:** Scales linearly with context length and batch size
+    - **Multi-GPU:** Tensor parallelism splits memory; data parallelism replicates it
+    - **Quantization:** GGUF/AWQ/GPTQ can reduce memory 2-8x with minimal quality loss
+
+    ### ⚠️ Disclaimers
+    - Estimates are approximate; actual usage varies by implementation
+    - Flash Attention and other optimizations can significantly reduce memory
    - GGUF models have different memory profiles than safetensors
+
+    Built with 💜 using Gradio & HuggingFace Hub API
    """)
 
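The training-mode rule added in `estimate_training_memory` above is the standard mixed-precision AdamW accounting: FP16 weights and gradients plus two FP32 optimizer states per parameter. A minimal replay of that arithmetic, with a hypothetical 8B-parameter model as input:

```python
# Hedged replay of estimate_training_memory's AdamW arithmetic.
# 2 B (FP16 weights) + 2 B (FP16 gradients) + 4 B x 2 (FP32 Adam m and v)
# = 12 bytes per parameter, before activations.
param_count = 8_000_000_000  # assumed model size, not from the commit
weights = param_count * 2
gradients = param_count * 2
optimizer = param_count * 4 * 2
total_base = weights + gradients + optimizer
print(total_base / 1024**3)  # ~89.4 GiB, versus ~14.9 GiB for FP16 weights alone
```

That is roughly six times the weight footprint before activation memory is counted, so the UI hint "Training requires ~4x more memory" is, if anything, on the low side.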
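The tensor-parallel split in `calculate_multi_gpu_split` is equally small arithmetic: memory divides evenly across devices, with a 5% communication surcharge spread over them. A sketch with assumed numbers (a 180 GB total footprint on 4 GPUs):

```python
# Hedged replay of calculate_multi_gpu_split's "Tensor Parallelism" branch.
total_gb, num_gpus = 180.0, 4   # assumed workload, not from the commit
overhead = 0.05 * total_gb      # 5% communication overhead, as in app.py
per_gpu = total_gb / num_gpus + overhead / num_gpus
print(per_gpu)  # 47.25 GB per card: fits 4x A100 80GB, not 4x A100 40GB
```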
requirements.txt CHANGED

@@ -1,2 +1,3 @@
 gradio>=4.44.0
 huggingface_hub>=0.20.0,<1.0.0
+pandas>=2.0.0
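The quantization table in the diff follows one sizing rule: bytes per parameter times parameter count, plus the script's 10% overhead factor. Replaying it for an assumed 70B model shows why only the most aggressive GGUF level squeezes under a 24 GB consumer card:

```python
# Hedged replay of the sizing rule behind app.py's QUANTIZATION_METHODS table.
param_count = 70_000_000_000  # assumed model size, not from the commit
for method, bytes_per_param in [
    ("FP16/BF16", 2.0),
    ("GGUF Q8_0", 1.0),
    ("GGUF Q4_K_M", 0.5),
    ("GGUF Q2_K", 0.3125),
]:
    size_gb = param_count * bytes_per_param / 1024**3 * 1.1  # +10% overhead, as in app.py
    verdict = "fits" if size_gb <= 24 else "does not fit"
    print(f"{method}: {size_gb:.1f} GB, {verdict} on a 24 GB GPU")
# FP16 ~143.4 GB, Q8_0 ~71.7 GB, Q4_K_M ~35.9 GB, Q2_K ~22.4 GB
```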