Spaces:
Running
Running
import matplotlib.pyplot as plt
def calculate_memory_components(
    hidden_size, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings
):
    """Estimate per-device training memory, in MiB, for a decoder model.

    Args:
        hidden_size: Model hidden dimension.
        num_layers: Total number of decoder layers.
        vocab_size: Vocabulary size.
        intermediate_size: MLP intermediate dimension.
        seq_len: Sequence length.
        mbs: Micro-batch size.
        batch_accum: Gradient-accumulation steps; only used when ``pp > 1``
            to cap the number of in-flight micro-batches.
        tp: Tensor-parallel degree; parameters and activations are divided
            by it.
        pp: Pipeline-parallel degree; layers are split as
            ``num_layers // pp``.
        dp: Data-parallel degree.
        zero_stage: ZeRO stage. Any non-zero value shards the FP32
            parameters and optimizer states across ``dp``; ``0`` disables
            sharding.
        tie_word_embeddings: Whether input and output embeddings share
            weights.

    Returns:
        A dict with two sub-dicts whose insertion order is relied upon by
        plotting code: ``"Components"`` (individual memory terms) and
        ``"Aggregates"`` (sums of those terms at different points in a
        training step).

    Raises:
        ZeroDivisionError: If ``tp``, ``pp`` or ``hidden_size`` is 0, or if
            ``dp`` is 0 while ``zero_stage`` is non-zero.
    """
    # Two bytes per BF16 value, expressed in MiB.
    mib_per_bf16_value = 2 / 1024 / 1024

    num_hidden_layers_in_pp = num_layers // pp

    # --- Model weights in BF16 ---------------------------------------------
    # The embedding matrix is counted twice only when the output head is
    # untied AND everything lives on a single pipeline stage.
    untied_on_single_stage = (not tie_word_embeddings) and pp == 1
    vocab_embeddings = (
        vocab_size * hidden_size * (2 if untied_on_single_stage else 1)
    )
    layer_params = (
        (hidden_size * 3 * hidden_size)          # qkv_proj
        + (hidden_size * hidden_size)            # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)      # down_proj
    )
    model_bf16 = (
        (vocab_embeddings + num_hidden_layers_in_pp * layer_params)
        * mib_per_bf16_value
        / tp
    )

    # --- Training state derived from the BF16 footprint --------------------
    dp_if_zero = 1 if zero_stage == 0 else dp
    fp32_params = 2 * model_bf16   # FP32 copy is twice the BF16 size
    fp32_grads = 2 * model_bf16
    optimstates = 4 * model_bf16   # presumably two FP32 moments - TODO confirm
    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    # NOTE(review): origin of the 72 / 32 constants is not shown here.
    overhead = 72 + 32 * mbs

    # Compute the ZeRO-sharded terms once so that every aggregate and the
    # reported component are guaranteed to use the same value.
    fp32_params_sharded = fp32_params / dp_if_zero
    optimstates_sharded = optimstates / dp_if_zero

    # --- Activations -------------------------------------------------------
    decoder_layer_mib = (
        (seq_len * mbs * hidden_size / tp)
        * mib_per_bf16_value
        * (4 * intermediate_size / hidden_size + 10)
    )
    if pp > 1:
        # At most min(pp, batch_accum) micro-batches are held at once.
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        # Without pipelining the logits are also accounted for, once for the
        # FP32 cast and once for the sharded cross-entropy (same size each).
        cast_to_fp32 = sharded_cross_entropy = (
            seq_len * mbs * vocab_size * mib_per_bf16_value * 2 / tp
        )
        activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy

    # --- Aggregates --------------------------------------------------------
    # Term order is kept stable so floating-point results do not shift.
    memory_usage_before_optimstates = (
        model_bf16
        + fp32_params_sharded
        + fp32_grads
        + ddp_grads_buffers
    )
    memory_usage_after_optimstates = (
        model_bf16
        + fp32_params_sharded
        + fp32_grads
        + optimstates_sharded
        + ddp_grads_buffers
        + overhead
    )
    # Peak is exactly the post-optimizer footprint plus activations.
    memory_usage_peak_tbi = memory_usage_after_optimstates + activs

    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": fp32_params_sharded,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": optimstates_sharded,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi,
        },
    }
def _plot_labeled_bars(data, title, color=None):
    """Draw one labelled bar chart of ``data`` (name -> MiB) on a new figure."""
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)

    names = list(data.keys())
    values = list(data.values())
    positions = range(len(names))

    # Only forward ``color`` when given so the default colour cycle is used
    # otherwise.
    bar_kwargs = {} if color is None else {"color": color}
    bars = ax.bar(positions, values, **bar_kwargs)

    # Annotate every bar with its value, centred just above the bar top.
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            height,
            f'{height:.1f} MiB',
            ha='center',
            va='bottom',
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Memory (MiB)')
    ax.set_title(title, pad=20)

    # Lay out this specific figure rather than the implicit "current" one,
    # so rotated tick labels are not clipped.
    fig.tight_layout()
    return fig


def plot_memory_breakdown(
    hidden_size, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings
):
    """Plot the memory estimate as two bar charts.

    All arguments are forwarded unchanged to
    ``calculate_memory_components``.

    Returns:
        A ``(fig1, fig2)`` tuple of matplotlib figures: ``fig1`` shows the
        individual components, ``fig2`` (orange bars) shows the aggregate
        metrics.

    Side effects:
        Closes every existing pyplot figure before drawing.
    """
    results = calculate_memory_components(
        hidden_size, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings
    )

    # Presumably prevents figures piling up across repeated calls
    # - TODO confirm against the caller.
    plt.close('all')

    fig1 = _plot_labeled_bars(
        results["Components"], 'Memory Component Breakdown'
    )
    fig2 = _plot_labeled_bars(
        results["Aggregates"], 'Aggregate Memory Metrics', color='orange'
    )
    return fig1, fig2