Spaces:
Running
Running
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import yaml | |
| import json | |
| from pathlib import Path | |
| import io | |
| from utils import calculate_memory_components, plot_memory_breakdown | |
def _params_from_json_config(config):
    """Build the calculator parameter dict from a parsed JSON model config.

    If the config has a ``text_config`` entry, the model dimensions are read
    from it; otherwise they are read from the top level. Training and
    parallelism settings are not part of this format, so fixed defaults are
    used for them.

    Raises:
        ValueError: if the parsed JSON is not an object (dict).
        KeyError: if a required model key is missing.
    """
    if not isinstance(config, dict):
        raise ValueError("expected a JSON object at the top level")

    if 'text_config' in config:
        model = config['text_config']
        # Prefer the nested vocabulary size; fall back to the top-level value
        # and finally to the default used for multimodal models.
        vocab_size = model.get('vocab_size', config.get('vocab_size', 256000))
    else:
        model = config
        vocab_size = config['vocab_size']

    return {
        'hidden_size': model['hidden_size'],
        'num_layers': model['num_hidden_layers'],
        'vocab_size': vocab_size,
        'intermediate_size': model['intermediate_size'],
        'seq_len': 2048,  # Default value since not in config
        'mbs': 1,  # Default value
        'batch_accum': 1,  # Default value
        'tp': 1,  # Default value
        'pp': 1,  # Default value
        'dp': 1,  # Default value
        'zero_stage': 0,  # Default value
        'tie_word_embeddings': config.get('tie_word_embeddings', True),
        'num_attention_heads': model['num_attention_heads'],
        # Without an explicit KV head count, assume one KV head per attention head.
        'num_key_value_heads': model.get('num_key_value_heads', model['num_attention_heads']),
        'full_checkpointing': False  # Default value
    }


def _params_from_yaml_config(config):
    """Build the calculator parameter dict from a parsed YAML training config.

    Reads the ``model.model_config``, ``parallelism``, ``tokens`` and
    ``optimizer`` sections.

    Raises:
        ValueError: if the parsed YAML is not a mapping (dict).
        KeyError: if a required section or key is missing.
    """
    if not isinstance(config, dict):
        raise ValueError("expected a YAML mapping at the top level")

    model_config = config['model']['model_config']
    parallelism = config['parallelism']
    tokens = config['tokens']
    optimizer = config['optimizer']
    return {
        'hidden_size': model_config['hidden_size'],
        'num_layers': model_config['num_hidden_layers'],
        'vocab_size': model_config['vocab_size'],
        'intermediate_size': model_config['intermediate_size'],
        'seq_len': tokens['sequence_length'],
        'mbs': tokens['micro_batch_size'],
        'batch_accum': tokens['batch_accumulation_per_replica'],
        'tp': parallelism['tp'],
        'pp': parallelism['pp'],
        'dp': parallelism['dp'],
        'zero_stage': optimizer['zero_stage'],
        'tie_word_embeddings': model_config['tie_word_embeddings'],
        'num_attention_heads': model_config['num_attention_heads'],
        'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
        'full_checkpointing': optimizer.get('full_checkpointing', False)  # Renamed from fsdp_checkpointing
    }


def load_config_from_content(content):
    """Parse pasted configuration text into the calculator parameter dict.

    The text is first parsed as JSON. If it is not valid JSON it is parsed as
    YAML with ``yaml.safe_load``.

    Args:
        content: configuration text (JSON or YAML).

    Returns:
        A dict with the keys ``hidden_size``, ``num_layers``, ``vocab_size``,
        ``intermediate_size``, ``seq_len``, ``mbs``, ``batch_accum``, ``tp``,
        ``pp``, ``dp``, ``zero_stage``, ``tie_word_embeddings``,
        ``num_attention_heads``, ``num_key_value_heads`` and
        ``full_checkpointing``.

    Raises:
        gr.Error: if the text cannot be parsed or a required key is missing.
    """
    try:
        # Try parsing as JSON first
        try:
            config = json.loads(content)
        except json.JSONDecodeError:
            # If not JSON, try YAML
            return _params_from_yaml_config(yaml.safe_load(content))
        return _params_from_json_config(config)
    except KeyError as e:
        # str(KeyError) is only the quoted key name, so say what went wrong.
        raise gr.Error(f"Error parsing configuration: missing required key {e}") from e
    except Exception as e:
        raise gr.Error(f"Error parsing configuration: {str(e)}") from e
def load_config_from_yaml_file(yaml_path):
    """Read a configuration file and parse it with load_config_from_content.

    Args:
        yaml_path: a filesystem path (``str`` or ``Path``) or an uploaded-file
            object exposing the path as ``.name``. Falsy values mean "no file".

    Returns:
        The parsed parameter dict, or ``None`` when no file was given.
    """
    if not yaml_path:
        return None
    # Path objects also have a ``.name`` attribute, but it is only the final
    # path component, so real paths must be used as-is.
    if isinstance(yaml_path, (str, Path)):
        path = yaml_path
    else:
        path = yaml_path.name
    with open(path, 'r', encoding='utf-8') as f:
        return load_config_from_content(f.read())
def format_config_display(config):
    """Render the configuration as an HTML snippet with one column per section.

    An estimated parameter count (in billions) is computed from the embedding
    matrix and the per-layer attention/MLP projection sizes and shown as
    ``num_params`` in the first column.

    Args:
        config: parameter dict; the model-architecture keys are required,
            any other displayed key that is absent is shown as ``N/A``.

    Returns:
        The HTML string, or ``"No configuration loaded"`` for a falsy config.
    """
    if not config:
        return "No configuration loaded"

    vocab = config['vocab_size']
    hidden = config['hidden_size']
    # Untied embeddings need a second vocab x hidden matrix.
    embedding_copies = 1 if config['tie_word_embeddings'] else 2
    kv_heads = config['num_key_value_heads']
    attn_heads = config['num_attention_heads']
    intermediate = config['intermediate_size']

    embedding_params = vocab * hidden * embedding_copies
    qkv_proj = hidden * hidden * (1 + 2*kv_heads/attn_heads)
    out_proj = hidden * hidden
    gate_up_proj = hidden * 2 * intermediate
    down_proj = intermediate * hidden
    per_layer = qkv_proj + out_proj + gate_up_proj + down_proj
    total = embedding_params + config['num_layers'] * per_layer
    billions = total / 1_000_000_000

    # Plain strings are looked up in config; (label, value) tuples are shown as given.
    layout = {
        "Model Architecture": [
            "hidden_size", "num_layers", "vocab_size",
            "intermediate_size", "tie_word_embeddings", "num_attention_heads", "num_key_value_heads",
            ("num_params", f"{billions:.2f}B"),
        ],
        "Training Configuration": ["seq_len", "mbs", "batch_accum"],
        "Parallelism": ["tp", "pp", "dp", "zero_stage", "full_checkpointing"],
    }

    def render_row(entry):
        if isinstance(entry, tuple):
            label, shown = entry
        else:
            label, shown = entry, config.get(entry, 'N/A')
        return f"<b>{label}</b>: {shown}<br>"

    pieces = ["<div style='display: flex;'>"]
    for title, entries in layout.items():
        pieces.append(f"<div style='flex: 1; padding-right: 20px;'><h3>{title}</h3>")
        pieces.extend(render_row(entry) for entry in entries)
        pieces.append("</div>")
    pieces.append("</div>")
    return "".join(pieces)
def process_yaml_and_plot(config):
    """Produce the plots, the config summary and the OOM verdict for a config.

    Args:
        config: parameter dict forwarded as keyword arguments to
            plot_memory_breakdown; a falsy value means nothing is loaded.

    Returns:
        A 4-tuple ``(fig1, fig2, summary_html, verdict)``. For a falsy config
        this is ``(None, None, "No configuration loaded", None)``.
    """
    if config:
        first_fig, second_fig, peak_memory = plot_memory_breakdown(**config)
        # NOTE(review): the unit of the 75000 threshold is not visible here — confirm against utils.
        if peak_memory > 75000:
            verdict = "OOM"
        else:
            verdict = "No OOM"
        return first_fig, second_fig, format_config_display(config), verdict
    return None, None, "No configuration loaded", None
with gr.Blocks() as demo:
    # NOTE(review): the original indentation was lost; the container nesting
    # below is reconstructed from the order of the components — confirm layout.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Configuration Input", open=True):
                config_text = gr.Textbox(
                    label="Paste YAML or JSON configuration",
                    placeholder="Paste your YAML or JSON configuration here...",
                    lines=10
                )
                config_submit = gr.Button("Calculate Memory from Config")
            with gr.Accordion("Manual Configuration", open=True):
                with gr.Accordion("Model Architecture", open=True):
                    with gr.Row():
                        hidden_size = gr.Number(4096, label="Hidden Size")
                        num_layers = gr.Number(32, label="Number of Layers")
                    with gr.Row():
                        vocab_size = gr.Number(50432, label="Vocabulary Size")
                        intermediate_size = gr.Number(11008, label="Intermediate Size")
                    with gr.Row():
                        num_attention_heads = gr.Number(32, label="Number of Attention Heads")
                        num_key_value_heads = gr.Number(32, label="Number of Key Value Heads")
                    tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")
                with gr.Accordion("Training Configuration", open=True):
                    with gr.Row():
                        seq_len = gr.Number(2048, label="Sequence Length")
                        mbs = gr.Number(1, label="Micro Batch Size")
                        batch_accum = gr.Number(1, label="Gradient Accumulation Steps")
                with gr.Accordion("Parallelism", open=True):
                    with gr.Row():
                        tp = gr.Number(1, label="Tensor Parallelism")
                        pp = gr.Number(1, label="Pipeline Parallelism")
                        dp = gr.Number(1, label="Data Parallelism")
                    zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
                    full_checkpointing = gr.Checkbox(False, label="Full Activation Checkpointing")
                manual_submit = gr.Button("Calculate Memory (Manual Input)")
        with gr.Column(scale=2):
            config_display = gr.Markdown(label="Configuration Values")
            oom_display = gr.Text(label="OOM Prediction")
            plot1 = gr.Plot(label="Memory Component Breakdown")
            plot2 = gr.Plot(label="Aggregate Memory Metrics")

    # Single source of truth for the manual-input fields: each config key is
    # paired with its component, in the order used by both event listeners.
    manual_fields = [
        ('hidden_size', hidden_size),
        ('num_attention_heads', num_attention_heads),
        ('num_key_value_heads', num_key_value_heads),
        ('num_layers', num_layers),
        ('vocab_size', vocab_size),
        ('intermediate_size', intermediate_size),
        ('seq_len', seq_len),
        ('mbs', mbs),
        ('batch_accum', batch_accum),
        ('tp', tp),
        ('pp', pp),
        ('dp', dp),
        ('zero_stage', zero_stage),
        ('tie_word_embeddings', tie_word_embeddings),
        ('full_checkpointing', full_checkpointing),  # Renamed from fsdp_checkpointing
    ]
    manual_field_keys = [key for key, _ in manual_fields]
    manual_field_components = [component for _, component in manual_fields]
    result_components = [plot1, plot2, config_display, oom_display]

    def process_yaml_and_update_ui(config):
        """Compute the results for a parsed config and mirror it into the manual fields.

        Returns a list with one value per result component followed by one
        value per manual-input field. For a falsy config the manual fields
        are left untouched via ``gr.update()``.
        """
        if not config:
            # One no-op update per manual field, so the number of returned
            # values always matches the number of wired outputs.
            return [None, None, "No configuration loaded", None] + [gr.update() for _ in manual_field_keys]
        fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
        oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
        # Result values first, then the UI component updates in field order.
        return [
            fig1, fig2,
            format_config_display(config),
            oom_prediction,
        ] + [config[key] for key in manual_field_keys]

    def config_text_to_ui(text):
        """Parse the pasted text (blank text means no config) and refresh the UI."""
        has_text = bool(text and text.strip())
        return process_yaml_and_update_ui(load_config_from_content(text) if has_text else None)

    # Handle config text input
    config_submit.click(
        config_text_to_ui,
        inputs=[config_text],
        outputs=result_components + manual_field_components
    )

    # Handle manual input
    def manual_input_to_config(*args):
        """Build a config from the manual field values and compute the results.

        ``args`` must follow the order of ``manual_field_keys``. Only the four
        result values are returned, matching the outputs wired below.
        """
        config = dict(zip(manual_field_keys, args))
        return process_yaml_and_plot(config)

    manual_submit.click(
        manual_input_to_config,
        inputs=manual_field_components,
        outputs=result_components
    )
| if __name__ == "__main__": | |
| demo.launch() | |