predict_memory / utils.py
nouamanetazi's picture
nouamanetazi HF Staff
init
5f67cc3
raw
history blame
4.91 kB
import matplotlib.pyplot as plt
def calculate_memory_components(
    hidden_size: int, num_layers: int, vocab_size: int, intermediate_size: int,
    seq_len: int, mbs: int, batch_accum: int, tp: int, pp: int, dp: int,
    zero_stage: int, tie_word_embeddings: bool
) -> dict:
    """Estimate training memory usage, in MiB, broken down by component.

    The estimate counts the embedding matrix plus four projection matrices
    per layer (qkv, out, gate_up, down), divides weights and activations by
    ``tp``, divides the layer count by ``pp`` and, when ``zero_stage`` is
    non-zero, divides the FP32 parameter copy and optimizer states by ``dp``.

    Args:
        hidden_size: Model hidden dimension.
        num_layers: Total number of decoder layers.
        vocab_size: Vocabulary size.
        intermediate_size: MLP intermediate dimension.
        seq_len: Sequence length.
        mbs: Micro-batch size.
        batch_accum: Number of gradient-accumulation steps; only used to cap
            the activation multiplier when ``pp > 1``.
        tp: Tensor-parallel degree.
        pp: Pipeline-parallel degree.
        dp: Data-parallel degree.
        zero_stage: ZeRO stage; ``0`` disables sharding across ``dp``.
        tie_word_embeddings: Whether input and output embeddings share weights.

    Returns:
        A dict with two sub-dicts, ``"Components"`` (one entry per memory
        component) and ``"Aggregates"`` (three running totals), all in MiB.

    Raises:
        ValueError: If ``hidden_size``, ``tp``, ``pp`` or ``dp`` is below 1.
            Each of them is used as a divisor below.
    """
    for label, value in (
        ("hidden_size", hidden_size), ("tp", tp), ("pp", pp), ("dp", dp)
    ):
        if value < 1:
            raise ValueError(f"{label} must be >= 1, got {value!r}")

    # 2 bytes per BF16 element, converted from bytes to MiB.
    bf16_mib_per_element = 2 / 1024 / 1024

    # Layers per pipeline stage; floor division drops any remainder.
    num_hidden_layers_in_pp = num_layers // pp

    # Model BF16 calculation. The embedding matrix is counted twice only
    # when embeddings are untied and there is no pipeline split.
    untied_without_pp = (not tie_word_embeddings) and pp == 1
    vocab_embeddings = vocab_size * hidden_size * (2 if untied_without_pp else 1)
    layer_params = (
        (hidden_size * 3 * hidden_size)  # qkv_proj
        + (hidden_size * hidden_size)  # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)  # down_proj
    )
    model_bf16 = (
        (vocab_embeddings + num_hidden_layers_in_pp * layer_params)
        * bf16_mib_per_element / tp
    )

    # Other components, all expressed as multiples of the BF16 model size.
    dp_if_zero = 1 if zero_stage == 0 else dp
    fp32_params = 2 * model_bf16
    fp32_grads = 2 * model_bf16
    # NOTE(review): 4x BF16 presumably means two FP32 states per parameter
    # (Adam-style moments) - confirm against the optimizer in use.
    optimstates = 4 * model_bf16
    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    # NOTE(review): 72 + 32 * mbs looks like an empirically fitted constant;
    # its origin is not visible in this file.
    overhead = 72 + 32 * mbs

    # Activations
    decoder_layer_mib = (
        (seq_len * mbs * hidden_size / tp)
        * bf16_mib_per_element
        * (4 * intermediate_size / hidden_size + 10)
    )
    if pp > 1:
        # With pipelining, up to min(pp, batch_accum) micro-batches are
        # counted as live at once on this stage's layers.
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        # Without pipelining the logits are counted as well: two equally
        # sized buffers (FP32 cast and cross-entropy), sharded by tp.
        cast_to_fp32 = sharded_cross_entropy = (
            seq_len * mbs * vocab_size * bf16_mib_per_element * 2 / tp
        )
        activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy

    fp32_params_sharded = fp32_params / dp_if_zero
    optimstates_sharded = optimstates / dp_if_zero

    # Calculate aggregate metrics. Summation order is kept explicit so the
    # floating-point results stay stable.
    memory_usage_before_optimstates = (
        model_bf16
        + fp32_params_sharded
        + fp32_grads
        + ddp_grads_buffers
    )
    memory_usage_after_optimstates = (
        model_bf16
        + fp32_params_sharded
        + fp32_grads
        + optimstates_sharded
        + ddp_grads_buffers
        + overhead
    )
    memory_usage_peak_tbi = (
        model_bf16
        + fp32_params_sharded
        + fp32_grads
        + optimstates_sharded
        + ddp_grads_buffers
        + overhead
        + activs
    )

    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": fp32_params_sharded,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": optimstates_sharded,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi,
        },
    }
def _labelled_bar_figure(title, series, color=None):
    """Draw ``series`` (label -> MiB value) as a bar chart on a new figure.

    Each bar is annotated with its height formatted as ``'<value> MiB'``.
    ``color`` is forwarded to ``Axes.bar`` only when given, so the default
    colour is used otherwise. Returns the new figure.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)

    labels = list(series.keys())
    heights = list(series.values())
    positions = range(len(series))

    bar_kwargs = {} if color is None else {"color": color}
    bars = ax.bar(positions, heights, **bar_kwargs)

    # Put the numeric value just above each bar, centred horizontally.
    for bar in bars:
        top = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., top,
                f'{top:.1f} MiB',
                ha='center', va='bottom')

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_ylabel('Memory (MiB)')
    ax.set_title(title, pad=20)

    # Adjust layout to prevent text overlap
    fig.tight_layout()
    return fig


def plot_memory_breakdown(
    hidden_size, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings
):
    """Plot the memory estimate as two bar charts.

    All arguments are forwarded unchanged to ``calculate_memory_components``.
    Every open pyplot figure is closed first, then two new figures are
    created: one for the ``"Components"`` entries and one, in orange, for
    the ``"Aggregates"`` entries.

    Returns:
        A ``(components_figure, aggregates_figure)`` tuple.
    """
    results = calculate_memory_components(
        hidden_size, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings
    )

    # Close figures left over from earlier calls before creating new ones.
    plt.close('all')

    fig1 = _labelled_bar_figure('Memory Component Breakdown', results["Components"])
    fig2 = _labelled_bar_figure(
        'Aggregate Memory Metrics', results["Aggregates"], color='orange'
    )
    return fig1, fig2