import matplotlib.pyplot as plt


def calculate_memory_components(
    hidden_size,
    num_layers,
    vocab_size,
    intermediate_size,
    seq_len,
    mbs,
    batch_accum,
    tp,
    pp,
    dp,
    zero_stage,
    tie_word_embeddings,
):
    """Estimate the memory (in MiB) held by one tensor/pipeline-parallel rank.

    The model is described by its layer shapes: each decoder layer has a fused
    QKV projection, an output projection, a fused gate/up projection and a
    down projection. Parameter memory is divided by ``tp`` and only the layers
    of one pipeline stage (``num_layers // pp``) are counted.

    Args:
        hidden_size: Model hidden dimension.
        num_layers: Total number of decoder layers.
        vocab_size: Vocabulary size (embedding / output-projection rows).
        intermediate_size: Feed-forward inner dimension.
        seq_len: Sequence length used in the activation terms.
        mbs: Micro-batch size; multiplies ``seq_len`` in the activation terms
            and scales the fixed overhead.
        batch_accum: Gradient-accumulation steps; with ``pp > 1`` it caps the
            number of micro-batches whose activations are held at once.
        tp, pp, dp: Tensor-, pipeline- and data-parallel degrees (each >= 1).
        zero_stage: ``0`` disables sharding; any other value shards FP32
            parameters and optimizer states across ``dp`` ranks.
        tie_word_embeddings: Whether input and output embeddings share weights.

    Returns:
        ``{"Components": {...}, "Aggregates": {...}}``, each mapping a label
        to a size in MiB.

    Raises:
        ValueError: if ``tp``, ``pp`` or ``dp`` is smaller than 1.
    """
    # Reject degenerate parallel layouts up front instead of failing with a
    # ZeroDivisionError below or silently reporting negative memory.
    for degree_name, degree in (("tp", tp), ("pp", pp), ("dp", dp)):
        if degree < 1:
            raise ValueError(f"{degree_name} must be >= 1, got {degree!r}")

    # NOTE(review): floor division -- when num_layers is not divisible by pp,
    # some stage holds more layers than this; confirm whether the most loaded
    # stage should be modelled instead.
    num_hidden_layers_in_pp = num_layers // pp

    # Untied embeddings count twice only when embedding and output projection
    # live on the same (single) pipeline stage.
    embedding_copies = 2 if (not tie_word_embeddings and pp == 1) else 1
    vocab_embeddings = vocab_size * hidden_size * embedding_copies
    layer_params = (
        (hidden_size * 3 * hidden_size)  # qkv_proj
        + (hidden_size * hidden_size)  # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)  # down_proj
    )
    # 2 bytes per BF16 element, converted to MiB, split across tp ranks.
    model_bf16 = (
        (vocab_embeddings + num_hidden_layers_in_pp * layer_params)
        * (2 / 1024 / 1024)
        / tp
    )

    # Sharding divisor for ZeRO-sharded tensors (1 means unsharded).
    dp_if_zero = 1 if zero_stage == 0 else dp
    fp32_params = 2 * model_bf16
    # NOTE(review): FP32 grads are never divided by dp_if_zero, i.e. they are
    # modelled as unsharded (ZeRO-1 behaviour) -- confirm for other stages.
    fp32_grads = 2 * model_bf16
    # Presumably two FP32 moments per parameter (Adam-style) -- TODO confirm.
    optimstates = 4 * model_bf16
    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    # Empirical fixed cost plus a per-sample term; origin not shown here.
    overhead = 72 + 32 * mbs

    # Activations of one decoder layer for one micro-batch, in MiB.
    decoder_layer_mib = (
        (seq_len * mbs * hidden_size / tp)
        * (2 / 1024 / 1024)
        * (4 * intermediate_size / hidden_size + 10)
    )
    if pp > 1:
        # Up to min(pp, batch_accum) micro-batches are in flight on a stage;
        # logit/cross-entropy terms are not added in this branch.
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        # Logits cast to FP32 and the sharded cross-entropy are modelled as
        # equal-sized buffers.
        cast_to_fp32 = sharded_cross_entropy = (
            seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp
        )
        activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy

    sharded_fp32_params = fp32_params / dp_if_zero
    sharded_optimstates = optimstates / dp_if_zero

    memory_usage_after_optimstates = (
        model_bf16
        + sharded_fp32_params
        + fp32_grads
        + sharded_optimstates
        + ddp_grads_buffers
        + overhead
    )
    # NOTE(review): overhead is excluded here but included in the other two
    # aggregates -- confirm this is intentional.
    memory_usage_before_optimstates = (
        model_bf16 + sharded_fp32_params + fp32_grads + ddp_grads_buffers
    )
    # Same left-to-right sum as "after", plus activations.
    memory_usage_peak_tbi = memory_usage_after_optimstates + activs

    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": sharded_fp32_params,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": sharded_optimstates,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi,
        },
    }


def _plot_bar_chart(data, title, **bar_kwargs):
    """Draw ``label -> MiB`` values as a labelled bar chart on a new figure.

    Extra keyword arguments are forwarded to ``Axes.bar`` (e.g. ``color``).
    Returns the new figure.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)

    names = list(data.keys())
    values = list(data.values())
    positions = range(len(data))
    bars = ax.bar(positions, values, **bar_kwargs)

    # Print each bar's value just above its top edge.
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f'{height:.1f} MiB',
            ha='center',
            va='bottom',
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Memory (MiB)')
    ax.set_title(title, pad=20)
    # Lay out this figure explicitly rather than relying on pyplot's
    # implicit "current figure".
    fig.tight_layout()
    return fig


def plot_memory_breakdown(
    hidden_size,
    num_layers,
    vocab_size,
    intermediate_size,
    seq_len,
    mbs,
    batch_accum,
    tp,
    pp,
    dp,
    zero_stage,
    tie_word_embeddings,
):
    """Plot the estimate from ``calculate_memory_components`` as two bar charts.

    All open pyplot figures are closed first. Takes the same arguments as
    ``calculate_memory_components`` and propagates its ``ValueError``.

    Returns:
        ``(fig1, fig2)``: the per-component figure and the aggregates figure.
    """
    results = calculate_memory_components(
        hidden_size, num_layers, vocab_size, intermediate_size, seq_len,
        mbs, batch_accum, tp, pp, dp, zero_stage, tie_word_embeddings,
    )

    # Drop figures from previous calls so repeated invocations don't leak them.
    plt.close('all')
    fig1 = _plot_bar_chart(results["Components"], 'Memory Component Breakdown')
    fig2 = _plot_bar_chart(
        results["Aggregates"], 'Aggregate Memory Metrics', color='orange'
    )
    return fig1, fig2