nouamanetazi HF Staff commited on
Commit
5a41adf
·
1 Parent(s): ed9dd0d

support VLMs

Browse files
Files changed (2) hide show
  1. app.py +47 -25
  2. utils.py +4 -4
app.py CHANGED
@@ -12,24 +12,46 @@ def load_config_from_content(content):
12
  # Try parsing as JSON first
13
  try:
14
  config = json.loads(content)
15
- # Convert JSON HF config format to our format
16
- return {
17
- 'hidden_size': config['hidden_size'],
18
- 'num_layers': config['num_hidden_layers'],
19
- 'vocab_size': config['vocab_size'],
20
- 'intermediate_size': config['intermediate_size'],
21
- 'seq_len': 2048, # Default value since not in config
22
- 'mbs': 1, # Default value
23
- 'batch_accum': 1, # Default value
24
- 'tp': 1, # Default value
25
- 'pp': 1, # Default value
26
- 'dp': 1, # Default value
27
- 'zero_stage': 0, # Default value
28
- 'tie_word_embeddings': config.get('tie_word_embeddings', True),
29
- 'num_attention_heads': config['num_attention_heads'],
30
- 'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
31
- 'fsdp_checkpointing': False # Default value
32
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except json.JSONDecodeError:
34
  # If not JSON, try YAML
35
  config = yaml.safe_load(content)
@@ -55,7 +77,7 @@ def load_config_from_content(content):
55
  'tie_word_embeddings': model_config['tie_word_embeddings'],
56
  'num_attention_heads': model_config['num_attention_heads'],
57
  'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
58
- 'fsdp_checkpointing': optimizer.get('fsdp_checkpointing', False) # Add FSDP checkpointing from config
59
  }
60
  except Exception as e:
61
  raise gr.Error(f"Error parsing configuration: {str(e)}")
@@ -92,7 +114,7 @@ def format_config_display(config):
92
  "seq_len", "mbs", "batch_accum"
93
  ],
94
  "Parallelism": [
95
- "tp", "pp", "dp", "zero_stage", "fsdp_checkpointing"
96
  ]
97
  }
98
 
@@ -154,7 +176,7 @@ with gr.Blocks() as demo:
154
  pp = gr.Number(1, label="Pipeline Parallelism")
155
  dp = gr.Number(1, label="Data Parallelism")
156
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
157
- fsdp_checkpointing = gr.Checkbox(False, label="FSDP Activation Checkpointing")
158
 
159
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
160
  with gr.Column(scale=2):
@@ -171,7 +193,7 @@ with gr.Blocks() as demo:
171
  plot1, plot2, config_display, oom_display,
172
  hidden_size, num_attention_heads, num_key_value_heads, num_layers,
173
  vocab_size, intermediate_size, seq_len, mbs, batch_accum,
174
- tp, pp, dp, zero_stage, tie_word_embeddings, fsdp_checkpointing
175
  ]
176
  )
177
 
@@ -202,7 +224,7 @@ with gr.Blocks() as demo:
202
  config['dp'],
203
  config['zero_stage'],
204
  config['tie_word_embeddings'],
205
- config['fsdp_checkpointing']
206
  ]
207
 
208
  # Handle manual input
@@ -222,7 +244,7 @@ with gr.Blocks() as demo:
222
  'tie_word_embeddings': args[13],
223
  'num_attention_heads': args[1],
224
  'num_key_value_heads': args[2],
225
- 'fsdp_checkpointing': args[14] # Add FSDP checkpointing
226
  }
227
  return process_yaml_and_update_ui(config)
228
 
@@ -231,7 +253,7 @@ with gr.Blocks() as demo:
231
  inputs=[
232
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
233
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
234
- tie_word_embeddings, fsdp_checkpointing # Add FSDP checkpointing
235
  ],
236
  outputs=[plot1, plot2, config_display, oom_display]
237
  )
 
12
  # Try parsing as JSON first
13
  try:
14
  config = json.loads(content)
15
+ # Check if this is a multimodal model with text_config
16
+ if 'text_config' in config:
17
+ # Use text_config for model parameters
18
+ text_config = config['text_config']
19
+ return {
20
+ 'hidden_size': text_config['hidden_size'],
21
+ 'num_layers': text_config['num_hidden_layers'],
22
+ 'vocab_size': config.get('vocab_size', 256000), # Default for multimodal models
23
+ 'intermediate_size': text_config['intermediate_size'],
24
+ 'seq_len': 2048, # Default value since not in config
25
+ 'mbs': 1, # Default value
26
+ 'batch_accum': 1, # Default value
27
+ 'tp': 1, # Default value
28
+ 'pp': 1, # Default value
29
+ 'dp': 1, # Default value
30
+ 'zero_stage': 0, # Default value
31
+ 'tie_word_embeddings': config.get('tie_word_embeddings', True),
32
+ 'num_attention_heads': text_config['num_attention_heads'],
33
+ 'num_key_value_heads': text_config.get('num_key_value_heads', text_config['num_attention_heads']),
34
+ 'full_checkpointing': False # Default value
35
+ }
36
+ else:
37
+ # Original code for non-multimodal models
38
+ return {
39
+ 'hidden_size': config['hidden_size'],
40
+ 'num_layers': config['num_hidden_layers'],
41
+ 'vocab_size': config['vocab_size'],
42
+ 'intermediate_size': config['intermediate_size'],
43
+ 'seq_len': 2048, # Default value since not in config
44
+ 'mbs': 1, # Default value
45
+ 'batch_accum': 1, # Default value
46
+ 'tp': 1, # Default value
47
+ 'pp': 1, # Default value
48
+ 'dp': 1, # Default value
49
+ 'zero_stage': 0, # Default value
50
+ 'tie_word_embeddings': config.get('tie_word_embeddings', True),
51
+ 'num_attention_heads': config['num_attention_heads'],
52
+ 'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
53
+ 'full_checkpointing': False # Default value
54
+ }
55
  except json.JSONDecodeError:
56
  # If not JSON, try YAML
57
  config = yaml.safe_load(content)
 
77
  'tie_word_embeddings': model_config['tie_word_embeddings'],
78
  'num_attention_heads': model_config['num_attention_heads'],
79
  'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
80
+ 'full_checkpointing': optimizer.get('full_checkpointing', False) # Renamed from fsdp_checkpointing
81
  }
82
  except Exception as e:
83
  raise gr.Error(f"Error parsing configuration: {str(e)}")
 
114
  "seq_len", "mbs", "batch_accum"
115
  ],
116
  "Parallelism": [
117
+ "tp", "pp", "dp", "zero_stage", "full_checkpointing"
118
  ]
119
  }
120
 
 
176
  pp = gr.Number(1, label="Pipeline Parallelism")
177
  dp = gr.Number(1, label="Data Parallelism")
178
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
179
+ full_checkpointing = gr.Checkbox(False, label="Full Activation Checkpointing")
180
 
181
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
182
  with gr.Column(scale=2):
 
193
  plot1, plot2, config_display, oom_display,
194
  hidden_size, num_attention_heads, num_key_value_heads, num_layers,
195
  vocab_size, intermediate_size, seq_len, mbs, batch_accum,
196
+ tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing
197
  ]
198
  )
199
 
 
224
  config['dp'],
225
  config['zero_stage'],
226
  config['tie_word_embeddings'],
227
+ config['full_checkpointing']
228
  ]
229
 
230
  # Handle manual input
 
244
  'tie_word_embeddings': args[13],
245
  'num_attention_heads': args[1],
246
  'num_key_value_heads': args[2],
247
+ 'full_checkpointing': args[14] # Renamed from fsdp_checkpointing
248
  }
249
  return process_yaml_and_update_ui(config)
250
 
 
253
  inputs=[
254
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
255
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
256
+ tie_word_embeddings, full_checkpointing # Renamed from fsdp_checkpointing
257
  ],
258
  outputs=[plot1, plot2, config_display, oom_display]
259
  )
utils.py CHANGED
@@ -51,7 +51,7 @@ def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediat
51
  def calculate_memory_components(
52
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
53
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
54
- tie_word_embeddings, fsdp_checkpointing=False
55
  ):
56
  # Calculate base components first
57
  if pp == 1:
@@ -107,7 +107,7 @@ def calculate_memory_components(
107
  base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
108
 
109
  # Apply activation reduction for FSDP checkpointing in ZeRO-3
110
- if zero_stage == 3 and fsdp_checkpointing:
111
  activs = base_activs / dp # Activation memory is reduced by dp factor with checkpointing
112
  else:
113
  activs = base_activs
@@ -163,12 +163,12 @@ def calculate_memory_components(
163
  def plot_memory_breakdown(
164
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
165
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
166
- tie_word_embeddings, fsdp_checkpointing=False
167
  ):
168
  results = calculate_memory_components(
169
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
170
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
171
- tie_word_embeddings, fsdp_checkpointing
172
  )
173
  memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]
174
 
 
51
  def calculate_memory_components(
52
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
53
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
54
+ tie_word_embeddings, full_checkpointing=False
55
  ):
56
  # Calculate base components first
57
  if pp == 1:
 
107
  base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
108
 
109
  # Apply activation reduction for FSDP checkpointing in ZeRO-3
110
+ if zero_stage == 3 and full_checkpointing:
111
  activs = base_activs / dp # Activation memory is reduced by dp factor with checkpointing
112
  else:
113
  activs = base_activs
 
163
  def plot_memory_breakdown(
164
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
165
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
166
+ tie_word_embeddings, full_checkpointing=False
167
  ):
168
  results = calculate_memory_components(
169
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
170
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
171
+ tie_word_embeddings, full_checkpointing
172
  )
173
  memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]
174