# LLM Training Memory Visualizer — Gradio app (Hugging Face Space)
import json
from functools import partial
from pathlib import Path

import gradio as gr
import pandas as pd

from calculator import MemoryCalculation
from defaults import DEFAULTS
from details import ACCURACY, DETAILS, INSTRUCTIONS, LIMITATIONS
from dtypes import DType
from state import Model, Parallelism, Training
# Factory for positive-integer inputs: a gr.Number pre-configured with
# minimum=1, integer step, zero decimal precision, and interactive=True.
# Callers pass label/value (and may override any of these defaults).
NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)
def create_parallelism_block():
    """Build the "Parallelism" column of the UI.

    Returns:
        Tuple of components: (tp, pp, cp, ep, fsdp_enabled,
        fsdp_parallelism, fsdp_strategy).
    """
    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            tp = NaturalNumber(label="Tensor Parallelism", value=1)
            pp = NaturalNumber(label="Pipeline Parallelism", value=1)
            cp = NaturalNumber(label="Context Parallelism", value=1)
            ep = NaturalNumber(label="Expert Parallelism", value=1)
            fsdp_enabled = gr.Checkbox(label="FSDP (Fully Sharded Data Parallel)", value=True)
            fsdp_parallelism = NaturalNumber(label="FSDP Parallelism", value=8)
            fsdp_strategy = gr.Radio(
                choices=["Zero-1", "Zero-2", "Zero-3"],
                label="FSDP Strategy",
                value="Zero-3",
            )

        def _toggle_fsdp_fields(enabled):
            # Grey out the FSDP sub-fields whenever the FSDP checkbox is off.
            classes = [] if enabled else ["disabled-field"]
            return [
                gr.update(interactive=enabled, elem_classes=classes),
                gr.update(interactive=enabled, elem_classes=classes),
            ]

        fsdp_enabled.change(
            fn=_toggle_fsdp_fields,
            inputs=fsdp_enabled,
            outputs=[fsdp_parallelism, fsdp_strategy],
        )
    return tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy
def create_model_block():
    """Build the "Model Architecture" column of the UI.

    Wires three behaviors:
      * the MoE checkbox enables/disables the expert-count fields,
      * selecting a preset populates every model field,
      * manually editing a field flips the preset dropdown to "Custom"
        (unless the values still match the selected preset exactly).

    Returns:
        Tuple of components: (layers, vocab, hidden, intermediate,
        active_experts, total_experts, is_moe, presets,
        weight_tied_embeddings).
    """
    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=128256)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=14336)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        active_experts = NaturalNumber(label="Active Experts", value=1, interactive=False, elem_classes="disabled-field")
        total_experts = NaturalNumber(label="Total Experts", value=1, interactive=False, elem_classes="disabled-field")
        weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=True)

        # Toggle expert fields' interactivity based on the MoE checkbox.
        is_moe.change(
            fn=lambda x: [
                gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"]),
                gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"])
            ],
            inputs=is_moe,
            outputs=[active_experts, total_experts]
        )

        # NOTE(review): assumes "Llama3 8B" is a key of DEFAULTS — verify.
        presets = gr.Dropdown(["Custom"] + list(DEFAULTS.keys()), label="Presets", value="Llama3 8B", interactive=True)

        def populate_from_preset(preset_name):
            """Fill all model fields from the chosen preset; no-op for "Custom"."""
            if preset_name and preset_name in DEFAULTS:
                model = DEFAULTS[preset_name]
                # BUGFIX: refresh elem_classes together with interactive.
                # Previously only `interactive` was updated, so the
                # "disabled-field" styling applied at creation (or by the
                # is_moe listener) went stale when a preset flipped is_moe.
                expert_classes = [] if model.is_moe else ["disabled-field"]
                return [
                    gr.update(value=model.num_layers),
                    gr.update(value=model.vocab_size),
                    gr.update(value=model.hidden_dim),
                    gr.update(value=model.intermediate_size),
                    gr.update(value=model.is_moe),
                    gr.update(value=model.active_experts, interactive=model.is_moe, elem_classes=expert_classes),
                    gr.update(value=model.total_experts, interactive=model.is_moe, elem_classes=expert_classes),
                    gr.update(value=model.weight_tied_embeddings)
                ]
            # Unknown/"Custom" selection: leave every field untouched.
            return [gr.update() for _ in range(8)]

        def switch_to_custom(layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val, active_experts_val, total_experts_val, weight_tied_val, current_preset):
            """Flip the dropdown to "Custom" once any field deviates from the preset."""
            # Don't switch to custom if the values still match a preset
            # exactly (this also covers the change events fired while a
            # preset is being applied).
            if current_preset and current_preset in DEFAULTS:
                model = DEFAULTS[current_preset]
                if (layers_val == model.num_layers and
                        vocab_val == model.vocab_size and
                        hidden_val == model.hidden_dim and
                        intermediate_val == model.intermediate_size and
                        is_moe_val == model.is_moe and
                        active_experts_val == model.active_experts and
                        total_experts_val == model.total_experts and
                        weight_tied_val == model.weight_tied_embeddings):
                    return gr.update()  # Keep current preset
            return gr.update(value="Custom")

        presets.change(
            fn=populate_from_preset,
            inputs=presets,
            outputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]
        )

        # Any manual edit of a model field may invalidate the preset.
        for input_component in [layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]:
            input_component.change(
                fn=switch_to_custom,
                inputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings, presets],
                outputs=presets
            )
    return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings
def create_training_block():
    """Build the "Training Config" column of the UI.

    Returns:
        Tuple of components: (seq_len, batch_size, gradient_checkpointing,
        grad_accumulation, precision, mixed_precision, param_dtype,
        reduce_dtype).
    """
    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=4096)
        batch_size = NaturalNumber(label="Batch Size", info="If you are using gradient accumulation, enter microbatch size", value=1)
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=True)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = gr.Dropdown(DType.values(), label="Precision", value=DType.BF16.value, interactive=True)
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = gr.Dropdown(DType.values(), label="Parameter Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")
        reduce_dtype = gr.Dropdown(DType.values(), label="Reduce Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")

        def _toggle_dtype_fields(enabled):
            # Mixed precision off -> grey out both dtype dropdowns.
            classes = [] if enabled else ["disabled-field"]
            return [
                gr.update(interactive=enabled, elem_classes=classes),
                gr.update(interactive=enabled, elem_classes=classes),
            ]

        mixed_precision.change(
            fn=_toggle_dtype_fields,
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )
    return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
| def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype): | |
| # Create state objects | |
| model_config = Model( | |
| vocab_size=int(vocab), | |
| num_layers=int(layers), | |
| hidden_dim=int(hidden), | |
| intermediate_size=int(intermediate), | |
| weight_tied_embeddings=weight_tied_embeddings, | |
| active_experts=int(active_experts), | |
| total_experts=int(total_experts), | |
| is_moe=is_moe | |
| ) | |
| parallelism_config = Parallelism( | |
| tensor_parallelism=int(tp), | |
| pipeline_parallelism=int(pp), | |
| context_parallelism=int(cp), | |
| expert_parallelism=int(ep), | |
| fsdp_enabled=fsdp_enabled, | |
| fsdp_parallelism=int(fsdp_parallelism), | |
| fsdp_strategy=fsdp_strategy | |
| ) | |
| training_config = Training( | |
| sequence_length=int(seq_len), | |
| batch_size=int(batch_size), | |
| gradient_checkpointing=gradient_checkpointing, | |
| grad_accumulation=grad_accumulation, | |
| precision=DType(precision), | |
| mixed_precision=mixed_precision, | |
| param_dtype=DType(param_dtype), | |
| reduce_dtype=DType(reduce_dtype) | |
| ) | |
| # Calculate different memory components | |
| calc = MemoryCalculation(model_config, parallelism_config, training_config) | |
| # Get all memory calculations | |
| param_memory = calc.calculate_parameter_memory() | |
| activation_memory = calc.calculate_activation_memory() | |
| gradient_memory = calc.calculate_gradient_memory() | |
| optimizer_memory = calc.calculate_optimizer_memory() | |
| # Calculate total memory | |
| total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory | |
| # Round to 1 decimal place for display | |
| param_gb = round(param_memory / 1e9, 1) | |
| activation_gb = round(activation_memory / 1e9, 1) | |
| gradient_gb = round(gradient_memory / 1e9, 1) | |
| optimizer_gb = round(optimizer_memory / 1e9, 1) | |
| total_gb = round(total_memory / 1e9, 1) | |
| # Create DataFrame for stacked bar plot | |
| # Start with stacked total bar, then add individual bars | |
| individual_data = [] | |
| # Stacked total bar first - create separate rows for each component within total | |
| for mem_type, gb_val in [ | |
| ('Activation', activation_gb), | |
| ('Optimizer', optimizer_gb), | |
| ('Gradient', gradient_gb), | |
| ('Parameter', param_gb) | |
| ]: | |
| individual_data.append({ | |
| 'Component': f'Total Memory\n{total_gb} GB', | |
| 'Memory (GB)': gb_val, | |
| 'Type': mem_type | |
| }) | |
| # Individual component bars | |
| for component, gb_val, mem_type in [ | |
| (f'Parameter Memory\n{param_gb} GB', param_gb, 'Parameter'), | |
| (f'Gradient Memory\n{gradient_gb} GB', gradient_gb, 'Gradient'), | |
| (f'Optimizer Memory\n{optimizer_gb} GB', optimizer_gb, 'Optimizer'), | |
| (f'Activation Memory\n{activation_gb} GB', activation_gb, 'Activation') | |
| ]: | |
| individual_data.append({ | |
| 'Component': component, | |
| 'Memory (GB)': gb_val, | |
| 'Type': mem_type | |
| }) | |
| memory_data = pd.DataFrame(individual_data) | |
| return gr.BarPlot( | |
| value=memory_data, | |
| x="Component", | |
| y="Memory (GB)", | |
| color="Type", | |
| title="LLM Memory Usage Breakdown", | |
| container=False, | |
| y_lim=[0, None], | |
| sort=[ | |
| f'Total Memory\n{total_gb} GB', | |
| f'Parameter Memory\n{param_gb} GB', | |
| f'Gradient Memory\n{gradient_gb} GB', | |
| f'Optimizer Memory\n{optimizer_gb} GB', | |
| f'Activation Memory\n{activation_gb} GB' | |
| ] | |
| ) | |
| css = """ | |
| /* Style for disabled components to make them visually obvious */ | |
| .disabled-field input, | |
| .disabled-field select, | |
| .disabled-field textarea { | |
| opacity: 0.4 !important; | |
| background-color: #f5f5f5 !important; | |
| color: #999 !important; | |
| cursor: not-allowed !important; | |
| text-decoration: line-through; | |
| } | |
| .disabled-field label { | |
| opacity: 0.5 !important; | |
| color: #999 !important; | |
| } | |
| """ | |
# ---- UI assembly -----------------------------------------------------------
# Three side-by-side input columns, a Calculate button driving a bar plot,
# and explanatory markdown sections underneath.
with gr.Blocks(theme='Default', css=css) as demo:
    with gr.Column():
        gr.Markdown("# LLM Training Memory Visualizer")
        gr.Markdown("<sub>🔧 Built by [Ruben Aghayan](https://www.linkedin.com/in/ruben-aghayan-37885690/)</sub>")
        gr.Markdown("---")
        gr.Markdown(INSTRUCTIONS)
        # One column per config group, laid out horizontally.
        with gr.Row(equal_height=True):
            tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy = create_parallelism_block()
            layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings = create_model_block()
            seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")
        # Input order must match the positional parameter order of calculate().
        calculate_button.click(
            fn=calculate,
            inputs=[
                tp,
                pp,
                cp,
                ep,
                fsdp_enabled,
                fsdp_parallelism,
                fsdp_strategy,
                layers,
                vocab,
                hidden,
                intermediate,
                active_experts,
                total_experts,
                is_moe,
                weight_tied_embeddings,
                seq_len,
                batch_size,
                gradient_checkpointing,
                grad_accumulation,
                precision,
                mixed_precision,
                param_dtype,
                reduce_dtype,
            ],
            outputs=output,
        )
        gr.Markdown("# Details")
        with gr.Row():
            gr.Markdown(LIMITATIONS)
            gr.Markdown(DETAILS)
        gr.Markdown("# Validation")
        gr.Markdown(ACCURACY)

# share=True exposes a public gradio.live URL in addition to the local server.
demo.launch(share=True)