# rubenaghayan's picture
# better defaults and validation section
# 64abcca
import json
from pathlib import Path
import gradio as gr
import pandas as pd
from functools import partial
from defaults import DEFAULTS
from details import ACCURACY, DETAILS, INSTRUCTIONS, LIMITATIONS
from state import Model, Parallelism, Training
from calculator import MemoryCalculation
from dtypes import DType
# Factory for a Gradio Number restricted to natural numbers (positive
# integers): minimum 1, integer step, zero decimal precision, editable.
# Built with functools.partial so call sites can still override any of
# these keyword arguments (e.g. interactive=False, elem_classes=...).
NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)
def create_parallelism_block():
    """Build the parallelism-settings column and return its input components.

    Returns the tuple (tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism,
    fsdp_strategy) of Gradio components, unpacked positionally by the caller.
    """
    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            tensor_p = NaturalNumber(label="Tensor Parallelism", value=1)
            pipeline_p = NaturalNumber(label="Pipeline Parallelism", value=1)
            context_p = NaturalNumber(label="Context Parallelism", value=1)
            expert_p = NaturalNumber(label="Expert Parallelism", value=1)
            fsdp_on = gr.Checkbox(label="FSDP (Fully Sharded Data Parallel)", value=True)
            fsdp_size = NaturalNumber(label="FSDP Parallelism", value=8)
            fsdp_mode = gr.Radio(
                choices=["Zero-1", "Zero-2", "Zero-3"],
                label="FSDP Strategy",
                value="Zero-3",
            )

            def _toggle_fsdp_fields(enabled):
                # Grey out (and lock) the FSDP sub-fields whenever the FSDP
                # checkbox itself is off; restore them when it is on.
                styling = [] if enabled else ["disabled-field"]
                return [
                    gr.update(interactive=enabled, elem_classes=styling),
                    gr.update(interactive=enabled, elem_classes=styling),
                ]

            fsdp_on.change(
                fn=_toggle_fsdp_fields,
                inputs=fsdp_on,
                outputs=[fsdp_size, fsdp_mode],
            )
    return tensor_p, pipeline_p, context_p, expert_p, fsdp_on, fsdp_size, fsdp_mode
def create_model_block():
    """Build the model-architecture column and return its input components.

    Returns the tuple (layers, vocab, hidden, intermediate, active_experts,
    total_experts, is_moe, presets, weight_tied_embeddings), unpacked
    positionally by the caller.
    """
    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=128256)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=14336)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        active_experts = NaturalNumber(label="Active Experts", value=1, interactive=False, elem_classes="disabled-field")
        total_experts = NaturalNumber(label="Total Experts", value=1, interactive=False, elem_classes="disabled-field")
        weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=True)

        # Toggle expert fields' interactivity (and greyed-out styling) based
        # on the MoE checkbox.
        is_moe.change(
            fn=lambda x: [
                gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"]),
                gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"])
            ],
            inputs=is_moe,
            outputs=[active_experts, total_experts]
        )

        presets = gr.Dropdown(["Custom"] + list(DEFAULTS.keys()), label="Presets", value="Llama3 8B", interactive=True)

        # Populate model parameters when a preset is selected.
        def populate_from_preset(preset_name):
            if preset_name and preset_name in DEFAULTS:
                model = DEFAULTS[preset_name]
                # Fix: keep elem_classes in sync with `interactive` so the
                # "disabled-field" grey-out styling matches what the
                # is_moe.change handler produces. Previously only
                # `interactive` was updated, leaving stale styling on the
                # expert fields after switching presets.
                expert_classes = [] if model.is_moe else ["disabled-field"]
                return [
                    gr.update(value=model.num_layers),
                    gr.update(value=model.vocab_size),
                    gr.update(value=model.hidden_dim),
                    gr.update(value=model.intermediate_size),
                    gr.update(value=model.is_moe),
                    gr.update(value=model.active_experts, interactive=model.is_moe, elem_classes=expert_classes),
                    gr.update(value=model.total_experts, interactive=model.is_moe, elem_classes=expert_classes),
                    gr.update(value=model.weight_tied_embeddings)
                ]
            # "Custom" (or unknown) preset: leave every field untouched.
            return [gr.update() for _ in range(8)]

        # Switch the dropdown to "Custom" when the user manually edits a
        # value, so the UI reflects that it no longer matches a preset.
        def switch_to_custom(layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val, active_experts_val, total_experts_val, weight_tied_val, current_preset):
            # Don't switch to custom if a preset is being applied: field
            # updates from populate_from_preset also fire .change, so keep
            # the preset when every value still matches it exactly.
            if current_preset and current_preset in DEFAULTS:
                model = DEFAULTS[current_preset]
                if (layers_val == model.num_layers and
                    vocab_val == model.vocab_size and
                    hidden_val == model.hidden_dim and
                    intermediate_val == model.intermediate_size and
                    is_moe_val == model.is_moe and
                    active_experts_val == model.active_experts and
                    total_experts_val == model.total_experts and
                    weight_tied_val == model.weight_tied_embeddings):
                    return gr.update()  # Keep current preset
            return gr.update(value="Custom")

        presets.change(
            fn=populate_from_preset,
            inputs=presets,
            outputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]
        )

        # Add change listeners to all model parameter inputs.
        for input_component in [layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]:
            input_component.change(
                fn=switch_to_custom,
                inputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings, presets],
                outputs=presets
            )
    return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings
def create_training_block():
    """Build the training-configuration column and return its input components.

    Returns the tuple (seq_len, batch_size, gradient_checkpointing,
    grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype),
    unpacked positionally by the caller.
    """
    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=4096)
        batch_size = NaturalNumber(label="Batch Size", info="If you are using gradient accumulation, enter microbatch size", value=1)
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=True)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = gr.Dropdown(DType.values(), label="Precision", value=DType.BF16.value, interactive=True)
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = gr.Dropdown(DType.values(), label="Parameter Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")
        reduce_dtype = gr.Dropdown(DType.values(), label="Reduce Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")

        def _sync_dtype_fields(enabled):
            # The two mixed-precision dtype dropdowns follow the checkbox:
            # editable with normal styling when on, greyed out when off.
            styling = [] if enabled else ["disabled-field"]
            return [
                gr.update(interactive=enabled, elem_classes=styling),
                gr.update(interactive=enabled, elem_classes=styling),
            ]

        mixed_precision.change(
            fn=_sync_dtype_fields,
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )
    return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
    """Translate the UI inputs into config objects, run the memory model,
    and render the breakdown as a bar plot.

    Returns a gr.BarPlot whose first bar stacks the four components into
    the total, followed by one standalone bar per component.
    """
    # Build the three config objects the calculator consumes. Gradio Number
    # values arrive as floats, so integer fields are coerced explicitly.
    model_config = Model(
        vocab_size=int(vocab),
        num_layers=int(layers),
        hidden_dim=int(hidden),
        intermediate_size=int(intermediate),
        weight_tied_embeddings=weight_tied_embeddings,
        active_experts=int(active_experts),
        total_experts=int(total_experts),
        is_moe=is_moe,
    )
    parallelism_config = Parallelism(
        tensor_parallelism=int(tp),
        pipeline_parallelism=int(pp),
        context_parallelism=int(cp),
        expert_parallelism=int(ep),
        fsdp_enabled=fsdp_enabled,
        fsdp_parallelism=int(fsdp_parallelism),
        fsdp_strategy=fsdp_strategy,
    )
    training_config = Training(
        sequence_length=int(seq_len),
        batch_size=int(batch_size),
        gradient_checkpointing=gradient_checkpointing,
        grad_accumulation=grad_accumulation,
        precision=DType(precision),
        mixed_precision=mixed_precision,
        param_dtype=DType(param_dtype),
        reduce_dtype=DType(reduce_dtype),
    )

    calc = MemoryCalculation(model_config, parallelism_config, training_config)

    # Raw per-component figures; the total is summed BEFORE rounding so the
    # displayed total is not a sum of rounded parts.
    param_memory = calc.calculate_parameter_memory()
    activation_memory = calc.calculate_activation_memory()
    gradient_memory = calc.calculate_gradient_memory()
    optimizer_memory = calc.calculate_optimizer_memory()
    total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory

    def _gb(nbytes):
        # Bytes -> decimal gigabytes, one decimal place for display.
        return round(nbytes / 1e9, 1)

    param_gb = _gb(param_memory)
    activation_gb = _gb(activation_memory)
    gradient_gb = _gb(gradient_memory)
    optimizer_gb = _gb(optimizer_memory)
    total_gb = _gb(total_memory)

    total_label = f'Total Memory\n{total_gb} GB'
    bar_labels = {
        'Parameter': f'Parameter Memory\n{param_gb} GB',
        'Gradient': f'Gradient Memory\n{gradient_gb} GB',
        'Optimizer': f'Optimizer Memory\n{optimizer_gb} GB',
        'Activation': f'Activation Memory\n{activation_gb} GB',
    }
    component_values = {
        'Parameter': param_gb,
        'Gradient': gradient_gb,
        'Optimizer': optimizer_gb,
        'Activation': activation_gb,
    }

    # First bar: all four components stacked under the total label
    # (bottom-to-top stacking order matches the original).
    rows = [
        {'Component': total_label, 'Memory (GB)': component_values[kind], 'Type': kind}
        for kind in ('Activation', 'Optimizer', 'Gradient', 'Parameter')
    ]
    # Then one standalone bar per component.
    rows += [
        {'Component': bar_labels[kind], 'Memory (GB)': component_values[kind], 'Type': kind}
        for kind in ('Parameter', 'Gradient', 'Optimizer', 'Activation')
    ]

    return gr.BarPlot(
        value=pd.DataFrame(rows),
        x="Component",
        y="Memory (GB)",
        color="Type",
        title="LLM Memory Usage Breakdown",
        container=False,
        y_lim=[0, None],
        sort=[
            total_label,
            bar_labels['Parameter'],
            bar_labels['Gradient'],
            bar_labels['Optimizer'],
            bar_labels['Activation'],
        ],
    )
# Custom CSS injected into the Blocks app below: components carrying the
# "disabled-field" class (toggled by the .change handlers above) render
# greyed-out and struck through so disabled inputs are visually obvious.
css = """
/* Style for disabled components to make them visually obvious */
.disabled-field input,
.disabled-field select,
.disabled-field textarea {
opacity: 0.4 !important;
background-color: #f5f5f5 !important;
color: #999 !important;
cursor: not-allowed !important;
text-decoration: line-through;
}
.disabled-field label {
opacity: 0.5 !important;
color: #999 !important;
}
"""
# Top-level app layout: header, the three config columns side by side, the
# Calculate button wired to the memory plot, and the details/validation prose.
with gr.Blocks(theme='Default', css=css) as demo:
    with gr.Column():
        gr.Markdown("# LLM Training Memory Visualizer")
        gr.Markdown("<sub>🔧 Built by [Ruben Aghayan](https://www.linkedin.com/in/ruben-aghayan-37885690/)</sub>")
        gr.Markdown("---")
        gr.Markdown(INSTRUCTIONS)
        # The three input sections render as equal-height side-by-side columns.
        with gr.Row(equal_height=True):
            tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy = create_parallelism_block()
            layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings = create_model_block()
            seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")
        # Input order here must match calculate()'s parameter order exactly.
        calculate_button.click(
            fn=calculate,
            inputs=[
                tp,
                pp,
                cp,
                ep,
                fsdp_enabled,
                fsdp_parallelism,
                fsdp_strategy,
                layers,
                vocab,
                hidden,
                intermediate,
                active_experts,
                total_experts,
                is_moe,
                weight_tied_embeddings,
                seq_len,
                batch_size,
                gradient_checkpointing,
                grad_accumulation,
                precision,
                mixed_precision,
                param_dtype,
                reduce_dtype,
            ],
            outputs=output,
        )
        gr.Markdown("# Details")
        with gr.Row():
            gr.Markdown(LIMITATIONS)
            gr.Markdown(DETAILS)
        gr.Markdown("# Validation")
        gr.Markdown(ACCURACY)
# share=True publishes a temporary public gradio.live URL in addition to
# serving locally.
demo.launch(share=True)