Spaces:
Sleeping
Sleeping
File size: 13,280 Bytes
75dbc58 84f0b80 97e312a 84f0b80 64abcca 97e312a 84f0b80 97e312a f45427d 64abcca f45427d 64abcca f45427d 84f0b80 b79954f 84f0b80 97e312a 64abcca 97e312a 64abcca f45427d 97e312a f45427d 97e312a f45427d 97e312a b79954f 97e312a 64abcca f45427d 84f0b80 b79954f 84f0b80 97e312a 64abcca 97e312a 64abcca 97e312a 64abcca 97e312a f45427d 97e312a f45427d 97e312a f45427d 97e312a 84f0b80 f45427d 97e312a f45427d 97e312a b79954f 97e312a f45427d 97e312a 84f0b80 97e312a 84f0b80 97e312a f45427d 97e312a f45427d 97e312a f45427d 97e312a f45427d 97e312a f9d6101 84f0b80 75dbc58 f9d6101 97e312a f45427d 97e312a 84f0b80 97e312a b79954f 97e312a 75dbc58 97e312a b79954f 75dbc58 f45427d 75dbc58 64abcca f45427d 64abcca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
import json
from pathlib import Path
import gradio as gr
import pandas as pd
from functools import partial
from defaults import DEFAULTS
from details import ACCURACY, DETAILS, INSTRUCTIONS, LIMITATIONS
from state import Model, Parallelism, Training
from calculator import MemoryCalculation
from dtypes import DType
# Factory for positive-integer inputs: a gr.Number restricted to whole
# numbers >= 1 that is editable by default. Callers may override any of
# these keyword arguments (e.g. interactive=False).
NaturalNumber = partial(
    gr.Number,
    minimum=1,
    step=1,
    precision=0,
    interactive=True,
)
def create_parallelism_block():
    """Build the "Parallelism" input column.

    The FSDP parallelism and strategy inputs follow the FSDP checkbox: they
    are editable while it is checked and greyed out (``disabled-field``)
    otherwise.

    Returns:
        ``(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy)``
        components, in that order.
    """

    def toggle_fsdp_fields(enabled):
        # One update per FSDP sub-field; each gets its own class list.
        return [
            gr.update(interactive=enabled, elem_classes=[] if enabled else ["disabled-field"])
            for _ in range(2)
        ]

    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            tp = NaturalNumber(label="Tensor Parallelism", value=1)
            pp = NaturalNumber(label="Pipeline Parallelism", value=1)
            cp = NaturalNumber(label="Context Parallelism", value=1)
            ep = NaturalNumber(label="Expert Parallelism", value=1)
            fsdp_enabled = gr.Checkbox(label="FSDP (Fully Sharded Data Parallel)", value=True)
            fsdp_parallelism = NaturalNumber(label="FSDP Parallelism", value=8)
            fsdp_strategy = gr.Radio(
                choices=["Zero-1", "Zero-2", "Zero-3"],
                label="FSDP Strategy",
                value="Zero-3",
            )
        fsdp_enabled.change(
            fn=toggle_fsdp_fields,
            inputs=fsdp_enabled,
            outputs=[fsdp_parallelism, fsdp_strategy],
        )
    return tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy
def create_model_block():
    """Build the "Model Architecture" input column with preset support.

    The widgets start out holding the values of the initially selected
    preset from ``DEFAULTS``, so the dropdown and the form agree on first
    render. Selecting a preset fills every field. Editing any field so that
    it no longer matches the selected preset switches the dropdown to
    "Custom".

    Returns:
        ``(layers, vocab, hidden, intermediate, active_experts, total_experts,
        is_moe, presets, weight_tied_embeddings)``. Note that ``is_moe`` comes
        after the expert counts, matching the caller's unpacking order.
    """
    initial_preset = "Llama3 8B"

    def _preset_values(model):
        # Preset attributes in the same order as the `fields` list below.
        return (
            model.num_layers,
            model.vocab_size,
            model.hidden_dim,
            model.intermediate_size,
            model.is_moe,
            model.active_experts,
            model.total_experts,
            model.weight_tied_embeddings,
        )

    def _expert_update(enabled, **extra):
        # Expert-count fields are only editable for MoE models; the
        # "disabled-field" class greys them out via the page CSS.
        return gr.update(
            interactive=enabled,
            elem_classes=[] if enabled else ["disabled-field"],
            **extra,
        )

    # Seed the widgets from the initial preset instead of duplicating its
    # numbers here. Fall back to the previous literals if that preset is
    # ever missing from DEFAULTS.
    if initial_preset in DEFAULTS:
        initial = _preset_values(DEFAULTS[initial_preset])
    else:
        initial = (32, 128256, 4096, 14336, False, 1, 1, True)
    (init_layers, init_vocab, init_hidden, init_intermediate,
     init_moe, init_active, init_total, init_tied) = initial
    init_expert_classes = [] if init_moe else "disabled-field"

    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=init_layers)
        vocab = NaturalNumber(label="Vocab Size", value=init_vocab)
        hidden = NaturalNumber(label="Hidden Dim", value=init_hidden)
        intermediate = NaturalNumber(label="Intermediate Dim", value=init_intermediate)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=init_moe)
        active_experts = NaturalNumber(
            label="Active Experts", value=init_active,
            interactive=init_moe, elem_classes=init_expert_classes,
        )
        total_experts = NaturalNumber(
            label="Total Experts", value=init_total,
            interactive=init_moe, elem_classes=init_expert_classes,
        )
        weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=init_tied)

        # Toggle expert fields' interactivity based on the MoE checkbox.
        is_moe.change(
            fn=lambda moe: [_expert_update(moe), _expert_update(moe)],
            inputs=is_moe,
            outputs=[active_experts, total_experts],
        )

        presets = gr.Dropdown(
            ["Custom"] + list(DEFAULTS.keys()),
            label="Presets",
            value=initial_preset,
            interactive=True,
        )

        fields = [layers, vocab, hidden, intermediate, is_moe,
                  active_experts, total_experts, weight_tied_embeddings]

        def populate_from_preset(preset_name):
            """Return one update per entry in `fields` for the chosen preset."""
            if not (preset_name and preset_name in DEFAULTS):
                # "Custom" (or an unknown name): leave every field untouched.
                return [gr.update() for _ in fields]
            values = _preset_values(DEFAULTS[preset_name])
            moe = values[4]
            updates = [gr.update(value=v) for v in values]
            # Expert counts also follow the preset's MoE flag, styling included.
            updates[5] = _expert_update(moe, value=values[5])
            updates[6] = _expert_update(moe, value=values[6])
            return updates

        def switch_to_custom(layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val,
                             active_experts_val, total_experts_val, weight_tied_val, current_preset):
            """Switch the dropdown to "Custom" unless the fields still match the preset."""
            current = (layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val,
                       active_experts_val, total_experts_val, weight_tied_val)
            # .change also fires when a preset is being applied; the values
            # then equal the preset's, so the selection is kept.
            if (current_preset and current_preset in DEFAULTS
                    and current == _preset_values(DEFAULTS[current_preset])):
                return gr.update()
            return gr.update(value="Custom")

        presets.change(fn=populate_from_preset, inputs=presets, outputs=fields)
        for component in fields:
            component.change(
                fn=switch_to_custom,
                inputs=fields + [presets],
                outputs=presets,
            )
    return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings
def create_training_block():
    """Build the "Training Config" input column.

    The parameter and reduce dtype dropdowns start disabled and become
    editable only while "Mixed Precision" is checked.

    Returns:
        ``(seq_len, batch_size, gradient_checkpointing, grad_accumulation,
        precision, mixed_precision, param_dtype, reduce_dtype)``.
    """

    def dtype_dropdown(label, default, **extra):
        # Every dtype selector offers the full DType list.
        return gr.Dropdown(DType.values(), label=label, value=default.value, **extra)

    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=4096)
        batch_size = NaturalNumber(
            label="Batch Size",
            info="If you are using gradient accumulation, enter microbatch size",
            value=1,
        )
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=True)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = dtype_dropdown("Precision", DType.BF16, interactive=True)
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = dtype_dropdown(
            "Parameter Dtype", DType.FP32, interactive=False, elem_classes="disabled-field"
        )
        reduce_dtype = dtype_dropdown(
            "Reduce Dtype", DType.FP32, interactive=False, elem_classes="disabled-field"
        )
        mixed_precision.change(
            fn=lambda enabled: [
                gr.update(interactive=enabled, elem_classes=[] if enabled else ["disabled-field"])
                for _ in range(2)
            ],
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )
    return (
        seq_len,
        batch_size,
        gradient_checkpointing,
        grad_accumulation,
        precision,
        mixed_precision,
        param_dtype,
        reduce_dtype,
    )
def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
# Create state objects
model_config = Model(
vocab_size=int(vocab),
num_layers=int(layers),
hidden_dim=int(hidden),
intermediate_size=int(intermediate),
weight_tied_embeddings=weight_tied_embeddings,
active_experts=int(active_experts),
total_experts=int(total_experts),
is_moe=is_moe
)
parallelism_config = Parallelism(
tensor_parallelism=int(tp),
pipeline_parallelism=int(pp),
context_parallelism=int(cp),
expert_parallelism=int(ep),
fsdp_enabled=fsdp_enabled,
fsdp_parallelism=int(fsdp_parallelism),
fsdp_strategy=fsdp_strategy
)
training_config = Training(
sequence_length=int(seq_len),
batch_size=int(batch_size),
gradient_checkpointing=gradient_checkpointing,
grad_accumulation=grad_accumulation,
precision=DType(precision),
mixed_precision=mixed_precision,
param_dtype=DType(param_dtype),
reduce_dtype=DType(reduce_dtype)
)
# Calculate different memory components
calc = MemoryCalculation(model_config, parallelism_config, training_config)
# Get all memory calculations
param_memory = calc.calculate_parameter_memory()
activation_memory = calc.calculate_activation_memory()
gradient_memory = calc.calculate_gradient_memory()
optimizer_memory = calc.calculate_optimizer_memory()
# Calculate total memory
total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory
# Round to 1 decimal place for display
param_gb = round(param_memory / 1e9, 1)
activation_gb = round(activation_memory / 1e9, 1)
gradient_gb = round(gradient_memory / 1e9, 1)
optimizer_gb = round(optimizer_memory / 1e9, 1)
total_gb = round(total_memory / 1e9, 1)
# Create DataFrame for stacked bar plot
# Start with stacked total bar, then add individual bars
individual_data = []
# Stacked total bar first - create separate rows for each component within total
for mem_type, gb_val in [
('Activation', activation_gb),
('Optimizer', optimizer_gb),
('Gradient', gradient_gb),
('Parameter', param_gb)
]:
individual_data.append({
'Component': f'Total Memory\n{total_gb} GB',
'Memory (GB)': gb_val,
'Type': mem_type
})
# Individual component bars
for component, gb_val, mem_type in [
(f'Parameter Memory\n{param_gb} GB', param_gb, 'Parameter'),
(f'Gradient Memory\n{gradient_gb} GB', gradient_gb, 'Gradient'),
(f'Optimizer Memory\n{optimizer_gb} GB', optimizer_gb, 'Optimizer'),
(f'Activation Memory\n{activation_gb} GB', activation_gb, 'Activation')
]:
individual_data.append({
'Component': component,
'Memory (GB)': gb_val,
'Type': mem_type
})
memory_data = pd.DataFrame(individual_data)
return gr.BarPlot(
value=memory_data,
x="Component",
y="Memory (GB)",
color="Type",
title="LLM Memory Usage Breakdown",
container=False,
y_lim=[0, None],
sort=[
f'Total Memory\n{total_gb} GB',
f'Parameter Memory\n{param_gb} GB',
f'Gradient Memory\n{gradient_gb} GB',
f'Optimizer Memory\n{optimizer_gb} GB',
f'Activation Memory\n{activation_gb} GB'
]
)
# Page-level CSS passed to gr.Blocks below. Input blocks add the
# "disabled-field" class (via elem_classes) to fields they make
# non-interactive: the FSDP options, the MoE expert counts and the
# mixed-precision dtypes. This makes those fields visibly greyed out and
# struck through rather than only un-editable.
css = """
/* Style for disabled components to make them visually obvious */
.disabled-field input,
.disabled-field select,
.disabled-field textarea {
opacity: 0.4 !important;
background-color: #f5f5f5 !important;
color: #999 !important;
cursor: not-allowed !important;
text-decoration: line-through;
}
.disabled-field label {
opacity: 0.5 !important;
color: #999 !important;
}
"""
# Assemble the page: header, the three input columns side by side, the
# Calculate button with its bar-plot output, then the explanatory sections.
with gr.Blocks(theme='Default', css=css) as demo:
    with gr.Column():
        gr.Markdown("# LLM Training Memory Visualizer")
        gr.Markdown("<sub>🔧 Built by [Ruben Aghayan](https://www.linkedin.com/in/ruben-aghayan-37885690/)</sub>")
        gr.Markdown("---")
        gr.Markdown(INSTRUCTIONS)
        with gr.Row(equal_height=True):
            tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy = create_parallelism_block()
            layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings = create_model_block()
            seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")
        # Input order must match calculate()'s positional parameters.
        calculate_button.click(
            fn=calculate,
            inputs=[
                tp,
                pp,
                cp,
                ep,
                fsdp_enabled,
                fsdp_parallelism,
                fsdp_strategy,
                layers,
                vocab,
                hidden,
                intermediate,
                active_experts,
                total_experts,
                is_moe,
                weight_tied_embeddings,
                seq_len,
                batch_size,
                gradient_checkpointing,
                grad_accumulation,
                precision,
                mixed_precision,
                param_dtype,
                reduce_dtype,
            ],
            outputs=output,
        )
        gr.Markdown("# Details")
        with gr.Row():
            gr.Markdown(LIMITATIONS)
            gr.Markdown(DETAILS)
        gr.Markdown("# Validation")
        gr.Markdown(ACCURACY)

# Start the server only when run as a script (`python <file>.py`).
# Importing the module (tests, Gradio reload mode) then just builds `demo`.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public tunnel link for local runs;
    # Gradio does not support share links on Hugging Face Spaces — confirm
    # this is still wanted.
    demo.launch(share=True)
|