nouamanetazi HF Staff committed on
Commit
9d879a4
·
1 Parent(s): 9a970ef

EXPERIMENTAL: add ZeRO stage 2 and stage 3 options

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. utils.py +37 -16
app.py CHANGED
@@ -107,7 +107,7 @@ with gr.Blocks() as demo:
107
  tp = gr.Number(1, label="Tensor Parallelism")
108
  pp = gr.Number(1, label="Pipeline Parallelism")
109
  dp = gr.Number(1, label="Data Parallelism")
110
- zero_stage = gr.Radio([0, 1], value=0, label="ZeRO Stage")
111
 
112
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
113
  with gr.Column(scale=2):
 
107
  tp = gr.Number(1, label="Tensor Parallelism")
108
  pp = gr.Number(1, label="Pipeline Parallelism")
109
  dp = gr.Number(1, label="Data Parallelism")
110
+ zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
111
 
112
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
113
  with gr.Column(scale=2):
utils.py CHANGED
@@ -57,7 +57,6 @@ def calculate_memory_components(
57
  if pp == 1:
58
  num_hidden_layers_in_pp = num_layers
59
  else:
60
- # num_hidden_layers_in_pp = num_layers // pp
61
  num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)
62
 
63
  # Model BF16 calculation
@@ -70,13 +69,26 @@ def calculate_memory_components(
70
  + (intermediate_size * hidden_size) # down_proj
71
  )
72
 
73
- model_bf16 = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Other components
76
- dp_if_zero = 1 if zero_stage == 0 else dp
77
- fp32_params = 2 * model_bf16
78
- fp32_grads = 2 * model_bf16
79
- optimstates = 4 * model_bf16
80
  use_ddp = zero_stage == 0 and dp > 1
81
  ddp_grads_buffers = model_bf16 if use_ddp else 0
82
  overhead = 72 + 32 * mbs
@@ -84,7 +96,6 @@ def calculate_memory_components(
84
  # Activations
85
  is_mha = num_key_value_heads == num_attention_heads
86
  decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
87
- # decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
88
 
89
  if pp > 1:
90
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
@@ -95,26 +106,29 @@ def calculate_memory_components(
95
  # Calculate aggregate metrics
96
  memory_usage_after_optimstates = (
97
  model_bf16 +
98
- fp32_params/dp_if_zero +
99
  fp32_grads +
100
- optimstates/dp_if_zero +
101
  ddp_grads_buffers +
 
102
  overhead
103
  )
104
 
105
  memory_usage_before_optimstates = (
106
  model_bf16 +
107
- fp32_params/dp_if_zero +
108
  fp32_grads +
109
- ddp_grads_buffers
 
110
  )
111
 
112
  memory_usage_peak_tbi = (
113
  model_bf16 +
114
- fp32_params/dp_if_zero +
115
  fp32_grads +
116
- optimstates/dp_if_zero +
117
  ddp_grads_buffers +
 
118
  overhead +
119
  activs
120
  )
@@ -122,10 +136,11 @@ def calculate_memory_components(
122
  return {
123
  "Components": {
124
  "Model BF16": model_bf16,
125
- "FP32 Parameters": fp32_params/dp_if_zero,
126
  "FP32 Gradients": fp32_grads,
127
- "Optimizer States": optimstates/dp_if_zero,
128
  "DDP Gradient Buffers": ddp_grads_buffers,
 
129
  "Overhead": overhead,
130
  "Activations": activs
131
  },
@@ -189,28 +204,33 @@ def plot_memory_breakdown(
189
  "Model Init": [
190
  ("Model BF16", c["Model BF16"]),
191
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
 
192
  ],
193
  "Gradient Accumulator Init": [
194
  ("Model BF16", c["Model BF16"]),
195
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
 
196
  ("FP32 Parameters", c["FP32 Parameters"]),
197
  ("FP32 Gradients", c["FP32 Gradients"])
198
  ],
199
  "Fwd-Bwd Peak": [
200
  ("Model BF16", c["Model BF16"]),
201
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
 
202
  ("FP32 Parameters", c["FP32 Parameters"]),
203
  ("FP32 Gradients", c["FP32 Gradients"]),
204
  ("Activations", c["Activations"])
205
  ],
206
  "Optimizer Step": [
207
  ("Model BF16", c["Model BF16"]),
 
208
  ("FP32 Parameters", c["FP32 Parameters"]),
209
  ("FP32 Gradients", c["FP32 Gradients"]),
210
  ("Optimizer States", c["Optimizer States"])
211
  ],
212
  "2nd Fwd-Bwd Peak": [
213
  ("Model BF16", c["Model BF16"]),
 
214
  ("FP32 Parameters", c["FP32 Parameters"]),
215
  ("FP32 Gradients", c["FP32 Gradients"]),
216
  ("Optimizer States", c["Optimizer States"]),
@@ -219,6 +239,7 @@ def plot_memory_breakdown(
219
  ],
220
  "2nd Optimizer Step": [
221
  ("Model BF16", c["Model BF16"]),
 
222
  ("FP32 Parameters", c["FP32 Parameters"]),
223
  ("FP32 Gradients", c["FP32 Gradients"]),
224
  ("Optimizer States", c["Optimizer States"]),
 
57
  if pp == 1:
58
  num_hidden_layers_in_pp = num_layers
59
  else:
 
60
  num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)
61
 
62
  # Model BF16 calculation
 
69
  + (intermediate_size * hidden_size) # down_proj
70
  )
71
 
72
+ model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp
73
+
74
+ # Adjust model components based on ZeRO stage
75
+ if zero_stage == 3:
76
+ # In ZeRO-3, model parameters are sharded across dp ranks
77
+ model_bf16 = model_bf16_full / dp
78
+ fp32_params = 2 * model_bf16
79
+ fp32_grads = 2 * model_bf16
80
+ optimstates = 4 * model_bf16
81
+ # Additional communication buffers for ZeRO-3
82
+ zero3_buffers = 2 * model_bf16 # For parameter gathering during forward/backward
83
+ else:
84
+ # For ZeRO-0/1/2
85
+ dp_if_zero = 1 if zero_stage == 0 else dp
86
+ model_bf16 = model_bf16_full
87
+ fp32_params = 2 * model_bf16 / dp_if_zero
88
+ fp32_grads = 2 * model_bf16
89
+ optimstates = 4 * model_bf16 / dp_if_zero
90
+ zero3_buffers = 0
91
 
 
 
 
 
 
92
  use_ddp = zero_stage == 0 and dp > 1
93
  ddp_grads_buffers = model_bf16 if use_ddp else 0
94
  overhead = 72 + 32 * mbs
 
96
  # Activations
97
  is_mha = num_key_value_heads == num_attention_heads
98
  decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
 
99
 
100
  if pp > 1:
101
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
 
106
  # Calculate aggregate metrics
107
  memory_usage_after_optimstates = (
108
  model_bf16 +
109
+ fp32_params +
110
  fp32_grads +
111
+ optimstates +
112
  ddp_grads_buffers +
113
+ zero3_buffers +
114
  overhead
115
  )
116
 
117
  memory_usage_before_optimstates = (
118
  model_bf16 +
119
+ fp32_params +
120
  fp32_grads +
121
+ ddp_grads_buffers +
122
+ zero3_buffers
123
  )
124
 
125
  memory_usage_peak_tbi = (
126
  model_bf16 +
127
+ fp32_params +
128
  fp32_grads +
129
+ optimstates +
130
  ddp_grads_buffers +
131
+ zero3_buffers +
132
  overhead +
133
  activs
134
  )
 
136
  return {
137
  "Components": {
138
  "Model BF16": model_bf16,
139
+ "FP32 Parameters": fp32_params,
140
  "FP32 Gradients": fp32_grads,
141
+ "Optimizer States": optimstates,
142
  "DDP Gradient Buffers": ddp_grads_buffers,
143
+ "ZeRO-3 Buffers": zero3_buffers,
144
  "Overhead": overhead,
145
  "Activations": activs
146
  },
 
204
  "Model Init": [
205
  ("Model BF16", c["Model BF16"]),
206
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
207
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
208
  ],
209
  "Gradient Accumulator Init": [
210
  ("Model BF16", c["Model BF16"]),
211
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
212
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
213
  ("FP32 Parameters", c["FP32 Parameters"]),
214
  ("FP32 Gradients", c["FP32 Gradients"])
215
  ],
216
  "Fwd-Bwd Peak": [
217
  ("Model BF16", c["Model BF16"]),
218
  ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
219
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
220
  ("FP32 Parameters", c["FP32 Parameters"]),
221
  ("FP32 Gradients", c["FP32 Gradients"]),
222
  ("Activations", c["Activations"])
223
  ],
224
  "Optimizer Step": [
225
  ("Model BF16", c["Model BF16"]),
226
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
227
  ("FP32 Parameters", c["FP32 Parameters"]),
228
  ("FP32 Gradients", c["FP32 Gradients"]),
229
  ("Optimizer States", c["Optimizer States"])
230
  ],
231
  "2nd Fwd-Bwd Peak": [
232
  ("Model BF16", c["Model BF16"]),
233
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
234
  ("FP32 Parameters", c["FP32 Parameters"]),
235
  ("FP32 Gradients", c["FP32 Gradients"]),
236
  ("Optimizer States", c["Optimizer States"]),
 
239
  ],
240
  "2nd Optimizer Step": [
241
  ("Model BF16", c["Model BF16"]),
242
+ ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
243
  ("FP32 Parameters", c["FP32 Parameters"]),
244
  ("FP32 Gradients", c["FP32 Gradients"]),
245
  ("Optimizer States", c["Optimizer States"]),