Spaces:
Running
Running
Commit ·
9d879a4
1
Parent(s): 9a970ef
EXPERIMENTAL: add ZeRO stage 2 and stage 3 options
Browse files
app.py
CHANGED
|
@@ -107,7 +107,7 @@ with gr.Blocks() as demo:
|
|
| 107 |
tp = gr.Number(1, label="Tensor Parallelism")
|
| 108 |
pp = gr.Number(1, label="Pipeline Parallelism")
|
| 109 |
dp = gr.Number(1, label="Data Parallelism")
|
| 110 |
-
zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
|
| 111 |
|
| 112 |
manual_submit = gr.Button("Calculate Memory (Manual Input)")
|
| 113 |
with gr.Column(scale=2):
|
|
|
|
| 107 |
tp = gr.Number(1, label="Tensor Parallelism")
|
| 108 |
pp = gr.Number(1, label="Pipeline Parallelism")
|
| 109 |
dp = gr.Number(1, label="Data Parallelism")
|
| 110 |
+
zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
|
| 111 |
|
| 112 |
manual_submit = gr.Button("Calculate Memory (Manual Input)")
|
| 113 |
with gr.Column(scale=2):
|
utils.py
CHANGED
|
@@ -57,7 +57,6 @@ def calculate_memory_components(
|
|
| 57 |
if pp == 1:
|
| 58 |
num_hidden_layers_in_pp = num_layers
|
| 59 |
else:
|
| 60 |
-
# num_hidden_layers_in_pp = num_layers // pp
|
| 61 |
num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)
|
| 62 |
|
| 63 |
# Model BF16 calculation
|
|
@@ -70,13 +69,26 @@ def calculate_memory_components(
|
|
| 70 |
+ (intermediate_size * hidden_size) # down_proj
|
| 71 |
)
|
| 72 |
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
# Other components
|
| 76 |
-
dp_if_zero = 1 if zero_stage == 0 else dp
|
| 77 |
-
fp32_params = 2 * model_bf16
|
| 78 |
-
fp32_grads = 2 * model_bf16
|
| 79 |
-
optimstates = 4 * model_bf16
|
| 80 |
use_ddp = zero_stage == 0 and dp > 1
|
| 81 |
ddp_grads_buffers = model_bf16 if use_ddp else 0
|
| 82 |
overhead = 72 + 32 * mbs
|
|
@@ -84,7 +96,6 @@ def calculate_memory_components(
|
|
| 84 |
# Activations
|
| 85 |
is_mha = num_key_value_heads == num_attention_heads
|
| 86 |
decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
|
| 87 |
-
# decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
|
| 88 |
|
| 89 |
if pp > 1:
|
| 90 |
activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
|
|
@@ -95,26 +106,29 @@ def calculate_memory_components(
|
|
| 95 |
# Calculate aggregate metrics
|
| 96 |
memory_usage_after_optimstates = (
|
| 97 |
model_bf16 +
|
| 98 |
-
fp32_params
|
| 99 |
fp32_grads +
|
| 100 |
-
optimstates
|
| 101 |
ddp_grads_buffers +
|
|
|
|
| 102 |
overhead
|
| 103 |
)
|
| 104 |
|
| 105 |
memory_usage_before_optimstates = (
|
| 106 |
model_bf16 +
|
| 107 |
-
fp32_params
|
| 108 |
fp32_grads +
|
| 109 |
-
ddp_grads_buffers
|
|
|
|
| 110 |
)
|
| 111 |
|
| 112 |
memory_usage_peak_tbi = (
|
| 113 |
model_bf16 +
|
| 114 |
-
fp32_params
|
| 115 |
fp32_grads +
|
| 116 |
-
optimstates
|
| 117 |
ddp_grads_buffers +
|
|
|
|
| 118 |
overhead +
|
| 119 |
activs
|
| 120 |
)
|
|
@@ -122,10 +136,11 @@ def calculate_memory_components(
|
|
| 122 |
return {
|
| 123 |
"Components": {
|
| 124 |
"Model BF16": model_bf16,
|
| 125 |
-
"FP32 Parameters": fp32_params
|
| 126 |
"FP32 Gradients": fp32_grads,
|
| 127 |
-
"Optimizer States": optimstates
|
| 128 |
"DDP Gradient Buffers": ddp_grads_buffers,
|
|
|
|
| 129 |
"Overhead": overhead,
|
| 130 |
"Activations": activs
|
| 131 |
},
|
|
@@ -189,28 +204,33 @@ def plot_memory_breakdown(
|
|
| 189 |
"Model Init": [
|
| 190 |
("Model BF16", c["Model BF16"]),
|
| 191 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
|
|
|
| 192 |
],
|
| 193 |
"Gradient Accumulator Init": [
|
| 194 |
("Model BF16", c["Model BF16"]),
|
| 195 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
|
|
|
| 196 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 197 |
("FP32 Gradients", c["FP32 Gradients"])
|
| 198 |
],
|
| 199 |
"Fwd-Bwd Peak": [
|
| 200 |
("Model BF16", c["Model BF16"]),
|
| 201 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
|
|
|
| 202 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 203 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 204 |
("Activations", c["Activations"])
|
| 205 |
],
|
| 206 |
"Optimizer Step": [
|
| 207 |
("Model BF16", c["Model BF16"]),
|
|
|
|
| 208 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 209 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 210 |
("Optimizer States", c["Optimizer States"])
|
| 211 |
],
|
| 212 |
"2nd Fwd-Bwd Peak": [
|
| 213 |
("Model BF16", c["Model BF16"]),
|
|
|
|
| 214 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 215 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 216 |
("Optimizer States", c["Optimizer States"]),
|
|
@@ -219,6 +239,7 @@ def plot_memory_breakdown(
|
|
| 219 |
],
|
| 220 |
"2nd Optimizer Step": [
|
| 221 |
("Model BF16", c["Model BF16"]),
|
|
|
|
| 222 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 223 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 224 |
("Optimizer States", c["Optimizer States"]),
|
|
|
|
| 57 |
if pp == 1:
|
| 58 |
num_hidden_layers_in_pp = num_layers
|
| 59 |
else:
|
|
|
|
| 60 |
num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)
|
| 61 |
|
| 62 |
# Model BF16 calculation
|
|
|
|
| 69 |
+ (intermediate_size * hidden_size) # down_proj
|
| 70 |
)
|
| 71 |
|
| 72 |
+
model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp
|
| 73 |
+
|
| 74 |
+
# Adjust model components based on ZeRO stage
|
| 75 |
+
if zero_stage == 3:
|
| 76 |
+
# In ZeRO-3, model parameters are sharded across dp ranks
|
| 77 |
+
model_bf16 = model_bf16_full / dp
|
| 78 |
+
fp32_params = 2 * model_bf16
|
| 79 |
+
fp32_grads = 2 * model_bf16
|
| 80 |
+
optimstates = 4 * model_bf16
|
| 81 |
+
# Additional communication buffers for ZeRO-3
|
| 82 |
+
zero3_buffers = 2 * model_bf16 # For parameter gathering during forward/backward
|
| 83 |
+
else:
|
| 84 |
+
# For ZeRO-0/1/2
|
| 85 |
+
dp_if_zero = 1 if zero_stage == 0 else dp
|
| 86 |
+
model_bf16 = model_bf16_full
|
| 87 |
+
fp32_params = 2 * model_bf16 / dp_if_zero
|
| 88 |
+
fp32_grads = 2 * model_bf16
|
| 89 |
+
optimstates = 4 * model_bf16 / dp_if_zero
|
| 90 |
+
zero3_buffers = 0
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
use_ddp = zero_stage == 0 and dp > 1
|
| 93 |
ddp_grads_buffers = model_bf16 if use_ddp else 0
|
| 94 |
overhead = 72 + 32 * mbs
|
|
|
|
| 96 |
# Activations
|
| 97 |
is_mha = num_key_value_heads == num_attention_heads
|
| 98 |
decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
|
|
|
|
| 99 |
|
| 100 |
if pp > 1:
|
| 101 |
activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
|
|
|
|
| 106 |
# Calculate aggregate metrics
|
| 107 |
memory_usage_after_optimstates = (
|
| 108 |
model_bf16 +
|
| 109 |
+
fp32_params +
|
| 110 |
fp32_grads +
|
| 111 |
+
optimstates +
|
| 112 |
ddp_grads_buffers +
|
| 113 |
+
zero3_buffers +
|
| 114 |
overhead
|
| 115 |
)
|
| 116 |
|
| 117 |
memory_usage_before_optimstates = (
|
| 118 |
model_bf16 +
|
| 119 |
+
fp32_params +
|
| 120 |
fp32_grads +
|
| 121 |
+
ddp_grads_buffers +
|
| 122 |
+
zero3_buffers
|
| 123 |
)
|
| 124 |
|
| 125 |
memory_usage_peak_tbi = (
|
| 126 |
model_bf16 +
|
| 127 |
+
fp32_params +
|
| 128 |
fp32_grads +
|
| 129 |
+
optimstates +
|
| 130 |
ddp_grads_buffers +
|
| 131 |
+
zero3_buffers +
|
| 132 |
overhead +
|
| 133 |
activs
|
| 134 |
)
|
|
|
|
| 136 |
return {
|
| 137 |
"Components": {
|
| 138 |
"Model BF16": model_bf16,
|
| 139 |
+
"FP32 Parameters": fp32_params,
|
| 140 |
"FP32 Gradients": fp32_grads,
|
| 141 |
+
"Optimizer States": optimstates,
|
| 142 |
"DDP Gradient Buffers": ddp_grads_buffers,
|
| 143 |
+
"ZeRO-3 Buffers": zero3_buffers,
|
| 144 |
"Overhead": overhead,
|
| 145 |
"Activations": activs
|
| 146 |
},
|
|
|
|
| 204 |
"Model Init": [
|
| 205 |
("Model BF16", c["Model BF16"]),
|
| 206 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
| 207 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 208 |
],
|
| 209 |
"Gradient Accumulator Init": [
|
| 210 |
("Model BF16", c["Model BF16"]),
|
| 211 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
| 212 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 213 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 214 |
("FP32 Gradients", c["FP32 Gradients"])
|
| 215 |
],
|
| 216 |
"Fwd-Bwd Peak": [
|
| 217 |
("Model BF16", c["Model BF16"]),
|
| 218 |
("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
|
| 219 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 220 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 221 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 222 |
("Activations", c["Activations"])
|
| 223 |
],
|
| 224 |
"Optimizer Step": [
|
| 225 |
("Model BF16", c["Model BF16"]),
|
| 226 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 227 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 228 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 229 |
("Optimizer States", c["Optimizer States"])
|
| 230 |
],
|
| 231 |
"2nd Fwd-Bwd Peak": [
|
| 232 |
("Model BF16", c["Model BF16"]),
|
| 233 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 234 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 235 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 236 |
("Optimizer States", c["Optimizer States"]),
|
|
|
|
| 239 |
],
|
| 240 |
"2nd Optimizer Step": [
|
| 241 |
("Model BF16", c["Model BF16"]),
|
| 242 |
+
("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
|
| 243 |
("FP32 Parameters", c["FP32 Parameters"]),
|
| 244 |
("FP32 Gradients", c["FP32 Gradients"]),
|
| 245 |
("Optimizer States", c["Optimizer States"]),
|