Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import functools | |
def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size):
    """Return how many decoder layers are placed on pipeline rank 0.

    The model is viewed as an ordered list of blocks:
    the embedding (cost 0), ``num_layers`` identical decoder layers, and the LM head.
    Each block is costed by a parameter-count proxy:
    attention is 4 * heads * head_dim * hidden, the MLP is 3 * intermediate * hidden,
    and the LM head is vocab * hidden.

    Blocks are assigned to rank 0 greedily. Assignment stops when either:
    - the cumulative cost reaches ``total_cost / pp_size``, or
    - no more non-zero-cost blocks remain than there are other ranks,
      so every later rank still gets at least one.

    Args:
        hidden_size: model hidden dimension.
        num_layers: number of decoder layers.
        vocab_size: vocabulary size (sizes the LM-head cost).
        intermediate_size: MLP intermediate dimension.
        num_attention_heads: number of attention heads.
        pp_size: pipeline-parallel degree.

    Returns:
        Number of decoder layers on rank 0, not counting the embedding block.
        Returns ``num_layers`` when ``pp_size == 1``.
    """
    if pp_size == 1:
        return num_layers

    head_dim = hidden_size // num_attention_heads
    decoder_cost = (4 * num_attention_heads * head_dim * hidden_size
                    + 3 * intermediate_size * hidden_size)
    # The embedding block is treated as free; the LM head costs vocab * hidden.
    block_costs = [0] + [decoder_cost] * num_layers + [vocab_size * hidden_size]

    target_cost_per_rank = sum(block_costs) / pp_size
    remaining_ranks = pp_size - 1  # ranks other than rank 0
    # Number of non-zero-cost blocks not yet assigned; decremented as we walk,
    # instead of re-scanning the tail of the list on every iteration.
    remaining_nonzero_blocks = sum(1 for cost in block_costs if cost > 0)

    blocks_in_rank0 = 0
    current_cost = 0
    for block_cost in block_costs:
        current_cost += block_cost
        blocks_in_rank0 += 1
        if block_cost > 0:
            remaining_nonzero_blocks -= 1
        if ((remaining_ranks > 0 and remaining_nonzero_blocks <= remaining_ranks)
                or current_cost >= target_cost_per_rank):
            break

    # Exclude the first block (the embedding), which is not a decoder layer.
    return blocks_in_rank0 - 1
def calculate_memory_components(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    """Estimate per-GPU training memory, in MiB, for a decoder-only transformer.

    Each decoder layer's parameters are counted as:
    - a fused qkv_proj (narrower when num_key_value_heads < num_attention_heads),
    - an out_proj,
    - a fused gate_up_proj,
    - a down_proj.

    When pp > 1, the parameter and activation terms describe the first pipeline
    stage: the embedding plus the decoder layers that
    get_num_hidden_layers_in_pp assigns to rank 0.

    Args:
        hidden_size, num_attention_heads, num_key_value_heads, num_layers,
            vocab_size, intermediate_size: model architecture.
        seq_len: sequence length.
        mbs: micro-batch size.
        batch_accum: gradient-accumulation steps. Only used when pp > 1, to
            cap the number of micro-batches whose activations are counted.
        tp, pp, dp: tensor-, pipeline- and data-parallel degrees.
        zero_stage: ZeRO stage.
            - 3 shards the bf16 model itself across dp.
            - Any other non-zero stage shards the FP32 parameters and
              optimizer states across dp.
            - 2 additionally shards the FP32 gradients.
            - 0 keeps full replicas and adds a DDP gradient buffer when dp > 1.
        tie_word_embeddings: whether input and output embeddings share weights.
            Only affects the estimate when pp == 1.
        full_checkpointing: when True with zero_stage == 3 and pp == 1,
            activation memory is divided by dp.

    Returns:
        A dict with two entries, all values in MiB:
        - "Components": per-component memory.
        - "Aggregates": before/after-optimizer-state totals and the peak.
    """
    mib_per_bf16 = 2 / 1024 / 1024  # 2 bytes per bf16 value, expressed in MiB

    # Decoder layers resident on this rank (rank 0 of the pipeline when pp > 1).
    if pp == 1:
        num_hidden_layers_in_pp = num_layers
    else:
        num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)

    # --- Parameters ---------------------------------------------------------
    # Untied embeddings only count the vocab matrix twice when a single stage
    # holds both the input embedding and the LM head.
    vocab_embeddings = vocab_size * hidden_size * (2 if (not tie_word_embeddings and pp == 1) else 1)
    layer_params = (
        (hidden_size * hidden_size * (1 + 2 * num_key_value_heads / num_attention_heads))  # qkv_proj
        + (hidden_size * hidden_size)  # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)  # down_proj
    )
    model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * mib_per_bf16 / tp

    # --- Model states, sharded according to the ZeRO stage -------------------
    if zero_stage == 3:
        # ZeRO-3: the model parameters themselves are sharded across dp ranks.
        model_bf16 = model_bf16_full / dp
        fp32_params = 2 * model_bf16
        fp32_grads = 2 * model_bf16
        optimstates = 4 * model_bf16
        # Extra buffers for gathering sharded parameters during forward/backward.
        zero3_buffers = 2 * model_bf16
    else:
        # ZeRO-0 keeps full replicas; ZeRO-1/2 shard FP32 params and optimizer states.
        dp_if_zero = 1 if zero_stage == 0 else dp
        # ZeRO-2 additionally partitions gradients across dp ranks.
        dp_for_grads = dp if zero_stage == 2 else 1
        model_bf16 = model_bf16_full
        fp32_params = 2 * model_bf16 / dp_if_zero
        fp32_grads = 2 * model_bf16 / dp_for_grads
        optimstates = 4 * model_bf16 / dp_if_zero
        zero3_buffers = 0

    # Plain data parallelism (no ZeRO) keeps a gradient buffer the size of the bf16 model.
    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    # Fixed overhead; the constants are not derived in this file (presumably empirical).
    overhead = 72 + 32 * mbs

    # --- Activations --------------------------------------------------------
    # bf16 activation footprint of one decoder layer for one micro-batch.
    decoder_layer_mib = (seq_len * mbs * hidden_size / tp) * mib_per_bf16 * (
        4 * intermediate_size / hidden_size + 6 + 2 * num_key_value_heads / num_attention_heads + 2
    )
    if pp > 1:
        # Presumably the number of micro-batches in flight on the first stage — confirm against the schedule.
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        # fp32 logits (4 bytes per vocab entry per token, sharded by tp), counted
        # once for the cast and once for the cross-entropy buffer.
        cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * mib_per_bf16 * 2 / tp
        base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
        if zero_stage == 3 and full_checkpointing:
            activs = base_activs / dp
        else:
            activs = base_activs

    # --- Aggregates (summation order kept explicit for reproducible floats) --
    memory_usage_after_optimstates = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        optimstates +
        ddp_grads_buffers +
        zero3_buffers +
        overhead
    )
    memory_usage_before_optimstates = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        ddp_grads_buffers +
        zero3_buffers
    )
    memory_usage_peak_tbi = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        optimstates +
        ddp_grads_buffers +
        zero3_buffers +
        overhead +
        activs
    )
    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": fp32_params,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": optimstates,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "ZeRO-3 Buffers": zero3_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi
        }
    }
def plot_memory_breakdown(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    """Draw the memory estimate as two matplotlib figures.

    All arguments are forwarded unchanged to calculate_memory_components.

    Returns:
        (fig1, fig2, peak), where:
        - fig1 is a bar chart with one labelled bar per memory component;
        - fig2 is a stacked bar chart of the components resident at each
          stage of a training step;
        - peak is the "Peak Memory (TBI)" aggregate in MiB.

    Side effect: closes every open pyplot figure before drawing.
    """
    estimate = calculate_memory_components(
        hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings, full_checkpointing
    )
    peak_mib = estimate["Aggregates"]["Peak Memory (TBI)"]
    parts = estimate["Components"]

    plt.close('all')

    # ---- Figure 1: one bar per component --------------------------------
    fig_parts = plt.figure(figsize=(10, 5))
    ax_parts = fig_parts.add_subplot(1, 1, 1)

    part_names = list(parts.keys())
    palette = plt.cm.Set3(np.linspace(0, 1, len(parts)))
    # The same colour per component is reused in the timeline figure.
    color_of = dict(zip(part_names, palette))
    part_slots = range(len(parts))

    for rect in ax_parts.bar(part_slots, list(parts.values()), color=palette):
        top = rect.get_height()
        ax_parts.text(rect.get_x() + rect.get_width() / 2., top, f'{top:.1f} MiB',
                      ha='center', va='bottom', rotation=0)

    ax_parts.set_xticks(part_slots)
    ax_parts.set_xticklabels(part_names, rotation=45, ha='right')
    ax_parts.set_ylabel('Memory (MiB)')
    ax_parts.set_title('Memory Component Breakdown', pad=20)
    fig_parts.tight_layout()

    # ---- Figure 2: stacked memory timeline -------------------------------
    fig_timeline = plt.figure(figsize=(10, 6))
    ax_timeline = fig_timeline.add_subplot(1, 1, 1)

    # Components resident in memory at each stage of training.
    stages = {
        "Model Init": ("Model BF16", "DDP Gradient Buffers", "ZeRO-3 Buffers"),
        "Gradient Accumulator Init": ("Model BF16", "DDP Gradient Buffers", "ZeRO-3 Buffers",
                                      "FP32 Parameters", "FP32 Gradients"),
        "Fwd-Bwd Peak": ("Model BF16", "DDP Gradient Buffers", "ZeRO-3 Buffers",
                         "FP32 Parameters", "FP32 Gradients", "Activations"),
        "Optimizer Step": ("Model BF16", "ZeRO-3 Buffers", "FP32 Parameters",
                           "FP32 Gradients", "Optimizer States"),
        "2nd Fwd-Bwd Peak": ("Model BF16", "ZeRO-3 Buffers", "FP32 Parameters", "FP32 Gradients",
                             "Optimizer States", "DDP Gradient Buffers", "Activations"),
        "2nd Optimizer Step": ("Model BF16", "ZeRO-3 Buffers", "FP32 Parameters", "FP32 Gradients",
                               "Optimizer States", "DDP Gradient Buffers"),
    }
    stage_slots = range(len(stages))
    stack = np.zeros(len(stages))  # running top of each stacked bar

    for name, size in parts.items():
        column = [size if name in resident else 0 for resident in stages.values()]
        ax_timeline.bar(stage_slots, column, bottom=stack, label=name, color=color_of[name])
        stack += column

    ax_timeline.set_xticks(stage_slots)
    ax_timeline.set_xticklabels(stages.keys(), rotation=45, ha='right')
    ax_timeline.set_ylabel('Memory (MiB)')
    ax_timeline.set_title('Memory Timeline', pad=20)

    # Total label above each stacked bar.
    for slot, total in enumerate(stack):
        ax_timeline.text(slot, total, f'{total:.1f} MiB', ha='center', va='bottom')

    fig_timeline.tight_layout()
    # y-axis floor of 80000 MiB (presumably an 80 GB accelerator) keeps small configs on a fixed scale.
    ax_timeline.set_ylim(0, max(80000, max(stack)))
    ax_timeline.legend(loc='lower center', bbox_to_anchor=(0.5, -1.5), ncol=3)

    return fig_parts, fig_timeline, peak_mib