nouamanetazi HF Staff committed on
Commit
1a15aaa
·
1 Parent(s): 2fa84c8

EXPERIMENTAL: add fsdp_checkpointing

Browse files
Files changed (2) hide show
  1. app.py +12 -7
  2. utils.py +11 -5
app.py CHANGED
@@ -27,7 +27,8 @@ def load_config_from_content(content):
27
  'zero_stage': 0, # Default value
28
  'tie_word_embeddings': config.get('tie_word_embeddings', True),
29
  'num_attention_heads': config['num_attention_heads'],
30
- 'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads'])
 
31
  }
32
  except json.JSONDecodeError:
33
  # If not JSON, try YAML
@@ -53,7 +54,8 @@ def load_config_from_content(content):
53
  'zero_stage': optimizer['zero_stage'],
54
  'tie_word_embeddings': model_config['tie_word_embeddings'],
55
  'num_attention_heads': model_config['num_attention_heads'],
56
- 'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads'])
 
57
  }
58
  except Exception as e:
59
  raise gr.Error(f"Error parsing configuration: {str(e)}")
@@ -77,7 +79,7 @@ def format_config_display(config):
77
  "seq_len", "mbs", "batch_accum"
78
  ],
79
  "Parallelism": [
80
- "tp", "pp", "dp", "zero_stage"
81
  ]
82
  }
83
 
@@ -134,6 +136,7 @@ with gr.Blocks() as demo:
134
  pp = gr.Number(1, label="Pipeline Parallelism")
135
  dp = gr.Number(1, label="Data Parallelism")
136
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
 
137
 
138
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
139
  with gr.Column(scale=2):
@@ -150,7 +153,7 @@ with gr.Blocks() as demo:
150
  plot1, plot2, config_display, oom_display,
151
  hidden_size, num_attention_heads, num_key_value_heads, num_layers,
152
  vocab_size, intermediate_size, seq_len, mbs, batch_accum,
153
- tp, pp, dp, zero_stage, tie_word_embeddings
154
  ]
155
  )
156
 
@@ -180,7 +183,8 @@ with gr.Blocks() as demo:
180
  config['pp'],
181
  config['dp'],
182
  config['zero_stage'],
183
- config['tie_word_embeddings']
 
184
  ]
185
 
186
  # Handle manual input
@@ -199,7 +203,8 @@ with gr.Blocks() as demo:
199
  'zero_stage': args[12],
200
  'tie_word_embeddings': args[13],
201
  'num_attention_heads': args[1],
202
- 'num_key_value_heads': args[2]
 
203
  }
204
  return process_yaml_and_update_ui(config)
205
 
@@ -208,7 +213,7 @@ with gr.Blocks() as demo:
208
  inputs=[
209
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
210
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
211
- tie_word_embeddings
212
  ],
213
  outputs=[plot1, plot2, config_display, oom_display]
214
  )
 
27
  'zero_stage': 0, # Default value
28
  'tie_word_embeddings': config.get('tie_word_embeddings', True),
29
  'num_attention_heads': config['num_attention_heads'],
30
+ 'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
31
+ 'fsdp_checkpointing': False # Default value
32
  }
33
  except json.JSONDecodeError:
34
  # If not JSON, try YAML
 
54
  'zero_stage': optimizer['zero_stage'],
55
  'tie_word_embeddings': model_config['tie_word_embeddings'],
56
  'num_attention_heads': model_config['num_attention_heads'],
57
+ 'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
58
+ 'fsdp_checkpointing': optimizer.get('fsdp_checkpointing', False) # Add FSDP checkpointing from config
59
  }
60
  except Exception as e:
61
  raise gr.Error(f"Error parsing configuration: {str(e)}")
 
79
  "seq_len", "mbs", "batch_accum"
80
  ],
81
  "Parallelism": [
82
+ "tp", "pp", "dp", "zero_stage", "fsdp_checkpointing"
83
  ]
84
  }
85
 
 
136
  pp = gr.Number(1, label="Pipeline Parallelism")
137
  dp = gr.Number(1, label="Data Parallelism")
138
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
139
+ fsdp_checkpointing = gr.Checkbox(False, label="FSDP Activation Checkpointing")
140
 
141
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
142
  with gr.Column(scale=2):
 
153
  plot1, plot2, config_display, oom_display,
154
  hidden_size, num_attention_heads, num_key_value_heads, num_layers,
155
  vocab_size, intermediate_size, seq_len, mbs, batch_accum,
156
+ tp, pp, dp, zero_stage, tie_word_embeddings, fsdp_checkpointing
157
  ]
158
  )
159
 
 
183
  config['pp'],
184
  config['dp'],
185
  config['zero_stage'],
186
+ config['tie_word_embeddings'],
187
+ config['fsdp_checkpointing']
188
  ]
189
 
190
  # Handle manual input
 
203
  'zero_stage': args[12],
204
  'tie_word_embeddings': args[13],
205
  'num_attention_heads': args[1],
206
+ 'num_key_value_heads': args[2],
207
+ 'fsdp_checkpointing': args[14] # Add FSDP checkpointing
208
  }
209
  return process_yaml_and_update_ui(config)
210
 
 
213
  inputs=[
214
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
215
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
216
+ tie_word_embeddings, fsdp_checkpointing # Add FSDP checkpointing
217
  ],
218
  outputs=[plot1, plot2, config_display, oom_display]
219
  )
utils.py CHANGED
@@ -51,7 +51,7 @@ def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediat
51
  def calculate_memory_components(
52
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
53
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
54
- tie_word_embeddings
55
  ):
56
  # Calculate base components first
57
  if pp == 1:
@@ -93,7 +93,7 @@ def calculate_memory_components(
93
  ddp_grads_buffers = model_bf16 if use_ddp else 0
94
  overhead = 72 + 32 * mbs
95
 
96
- # Activations
97
  is_mha = num_key_value_heads == num_attention_heads
98
  decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
99
 
@@ -101,7 +101,13 @@ def calculate_memory_components(
101
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
102
  else:
103
  cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp
104
- activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
 
 
 
 
 
 
105
 
106
  # Calculate aggregate metrics
107
  memory_usage_after_optimstates = (
@@ -154,12 +160,12 @@ def calculate_memory_components(
154
  def plot_memory_breakdown(
155
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
156
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
157
- tie_word_embeddings
158
  ):
159
  results = calculate_memory_components(
160
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
161
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
162
- tie_word_embeddings
163
  )
164
  memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]
165
 
 
51
  def calculate_memory_components(
52
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
53
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
54
+ tie_word_embeddings, fsdp_checkpointing=False
55
  ):
56
  # Calculate base components first
57
  if pp == 1:
 
93
  ddp_grads_buffers = model_bf16 if use_ddp else 0
94
  overhead = 72 + 32 * mbs
95
 
96
+ # Activations calculation with FSDP checkpointing support
97
  is_mha = num_key_value_heads == num_attention_heads
98
  decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
99
 
 
101
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
102
  else:
103
  cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp
104
+ base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
105
+
106
+ # Apply activation reduction for FSDP checkpointing in ZeRO-3
107
+ if zero_stage == 3 and fsdp_checkpointing:
108
+ activs = base_activs / dp # Activation memory is reduced by dp factor with checkpointing
109
+ else:
110
+ activs = base_activs
111
 
112
  # Calculate aggregate metrics
113
  memory_usage_after_optimstates = (
 
160
  def plot_memory_breakdown(
161
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
162
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
163
+ tie_word_embeddings, fsdp_checkpointing=False
164
  ):
165
  results = calculate_memory_components(
166
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
167
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
168
+ tie_word_embeddings, fsdp_checkpointing
169
  )
170
  memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]
171