nouamanetazi HF Staff commited on
Commit
2fa84c8
·
1 Parent(s): 9d879a4

qol updated

Browse files
Files changed (1) hide show
  1. app.py +112 -59
app.py CHANGED
@@ -1,45 +1,68 @@
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import yaml
 
4
  from pathlib import Path
5
  import io
6
  from utils import calculate_memory_components, plot_memory_breakdown
7
 
8
 
9
- def load_config_from_yaml_content(yaml_content):
10
  try:
11
- config = yaml.safe_load(yaml_content)
12
-
13
- # Extract relevant parameters from config
14
- model_config = config['model']['model_config']
15
- parallelism = config['parallelism']
16
- tokens = config['tokens']
17
- optimizer = config['optimizer']
18
-
19
- return {
20
- 'hidden_size': model_config['hidden_size'],
21
- 'num_layers': model_config['num_hidden_layers'],
22
- 'vocab_size': model_config['vocab_size'],
23
- 'intermediate_size': model_config['intermediate_size'],
24
- 'seq_len': tokens['sequence_length'],
25
- 'mbs': tokens['micro_batch_size'],
26
- 'batch_accum': tokens['batch_accumulation_per_replica'],
27
- 'tp': parallelism['tp'],
28
- 'pp': parallelism['pp'],
29
- 'dp': parallelism['dp'],
30
- 'zero_stage': optimizer['zero_stage'],
31
- 'tie_word_embeddings': model_config['tie_word_embeddings'],
32
- 'num_attention_heads': model_config['num_attention_heads'],
33
- 'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads'])
34
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
- raise gr.Error(f"Error parsing YAML: {str(e)}")
37
 
38
  def load_config_from_yaml_file(yaml_path):
39
  if not yaml_path:
40
  return None
41
  with open(yaml_path.name, 'r') as f:
42
- return load_config_from_yaml_content(f.read())
43
 
44
  def format_config_display(config):
45
  if not config:
@@ -75,38 +98,41 @@ def process_yaml_and_plot(config):
75
  fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
76
  oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
77
  return fig1, fig2, format_config_display(config), oom_prediction
78
-
79
  with gr.Blocks() as demo:
80
  with gr.Row():
81
  with gr.Column(scale=1):
82
- with gr.Accordion("YAML Configuration", open=True):
83
- yaml_file = gr.File(label="Upload YAML Config", file_types=[".yaml", ".yml"])
84
- yaml_text = gr.Textbox(
85
- label="Or paste YAML content here",
86
- placeholder="Paste your YAML configuration here...",
87
  lines=10
88
  )
89
- yaml_submit = gr.Button("Calculate Memory from YAML")
90
 
91
- with gr.Accordion("Manual Configuration", open=False):
92
  with gr.Accordion("Model Architecture", open=True):
93
- hidden_size = gr.Number(4096, label="Hidden Size")
94
- num_layers = gr.Number(32, label="Number of Layers")
95
- vocab_size = gr.Number(50432, label="Vocabulary Size")
96
- intermediate_size = gr.Number(11008, label="Intermediate Size")
 
 
 
 
 
97
  tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")
98
- num_attention_heads = gr.Number(32, label="Number of Attention Heads")
99
- num_key_value_heads = gr.Number(32, label="Number of Key Value Heads")
100
 
101
  with gr.Accordion("Training Configuration", open=True):
102
- seq_len = gr.Number(2048, label="Sequence Length")
103
- mbs = gr.Number(1, label="Micro Batch Size")
104
- batch_accum = gr.Number(1, label="Gradient Accumulation Steps")
 
105
 
106
  with gr.Accordion("Parallelism", open=True):
107
- tp = gr.Number(1, label="Tensor Parallelism")
108
- pp = gr.Number(1, label="Pipeline Parallelism")
109
- dp = gr.Number(1, label="Data Parallelism")
 
110
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
111
 
112
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
@@ -116,19 +142,46 @@ with gr.Blocks() as demo:
116
  plot1 = gr.Plot(label="Memory Component Breakdown")
117
  plot2 = gr.Plot(label="Aggregate Memory Metrics")
118
 
119
- # Handle YAML file upload
120
- yaml_file.change(
121
- lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
122
- inputs=[yaml_file],
123
- outputs=[plot1, plot2, config_display, oom_display]
 
 
 
 
 
124
  )
125
 
126
- # Handle YAML text input
127
- yaml_submit.click(
128
- lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
129
- inputs=[yaml_text],
130
- outputs=[plot1, plot2, config_display, oom_display]
131
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  # Handle manual input
134
  def manual_input_to_config(*args):
@@ -148,7 +201,7 @@ with gr.Blocks() as demo:
148
  'num_attention_heads': args[1],
149
  'num_key_value_heads': args[2]
150
  }
151
- return process_yaml_and_plot(config)
152
 
153
  manual_submit.click(
154
  manual_input_to_config,
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import yaml
4
+ import json
5
  from pathlib import Path
6
  import io
7
  from utils import calculate_memory_components, plot_memory_breakdown
8
 
9
 
10
+ def load_config_from_content(content):
11
  try:
12
+ # Try parsing as JSON first
13
+ try:
14
+ config = json.loads(content)
15
+ # Convert JSON HF config format to our format
16
+ return {
17
+ 'hidden_size': config['hidden_size'],
18
+ 'num_layers': config['num_hidden_layers'],
19
+ 'vocab_size': config['vocab_size'],
20
+ 'intermediate_size': config['intermediate_size'],
21
+ 'seq_len': 2048, # Default value since not in config
22
+ 'mbs': 1, # Default value
23
+ 'batch_accum': 1, # Default value
24
+ 'tp': 1, # Default value
25
+ 'pp': 1, # Default value
26
+ 'dp': 1, # Default value
27
+ 'zero_stage': 0, # Default value
28
+ 'tie_word_embeddings': config.get('tie_word_embeddings', True),
29
+ 'num_attention_heads': config['num_attention_heads'],
30
+ 'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads'])
31
+ }
32
+ except json.JSONDecodeError:
33
+ # If not JSON, try YAML
34
+ config = yaml.safe_load(content)
35
+
36
+ # Extract relevant parameters from YAML config
37
+ model_config = config['model']['model_config']
38
+ parallelism = config['parallelism']
39
+ tokens = config['tokens']
40
+ optimizer = config['optimizer']
41
+
42
+ return {
43
+ 'hidden_size': model_config['hidden_size'],
44
+ 'num_layers': model_config['num_hidden_layers'],
45
+ 'vocab_size': model_config['vocab_size'],
46
+ 'intermediate_size': model_config['intermediate_size'],
47
+ 'seq_len': tokens['sequence_length'],
48
+ 'mbs': tokens['micro_batch_size'],
49
+ 'batch_accum': tokens['batch_accumulation_per_replica'],
50
+ 'tp': parallelism['tp'],
51
+ 'pp': parallelism['pp'],
52
+ 'dp': parallelism['dp'],
53
+ 'zero_stage': optimizer['zero_stage'],
54
+ 'tie_word_embeddings': model_config['tie_word_embeddings'],
55
+ 'num_attention_heads': model_config['num_attention_heads'],
56
+ 'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads'])
57
+ }
58
  except Exception as e:
59
+ raise gr.Error(f"Error parsing configuration: {str(e)}")
60
 
61
  def load_config_from_yaml_file(yaml_path):
62
  if not yaml_path:
63
  return None
64
  with open(yaml_path.name, 'r') as f:
65
+ return load_config_from_content(f.read())
66
 
67
  def format_config_display(config):
68
  if not config:
 
98
  fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
99
  oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
100
  return fig1, fig2, format_config_display(config), oom_prediction
 
101
  with gr.Blocks() as demo:
102
  with gr.Row():
103
  with gr.Column(scale=1):
104
+ with gr.Accordion("Configuration Input", open=True):
105
+ config_text = gr.Textbox(
106
+ label="Paste YAML or JSON configuration",
107
+ placeholder="Paste your YAML or JSON configuration here...",
 
108
  lines=10
109
  )
110
+ config_submit = gr.Button("Calculate Memory from Config")
111
 
112
+ with gr.Accordion("Manual Configuration", open=True):
113
  with gr.Accordion("Model Architecture", open=True):
114
+ with gr.Row():
115
+ hidden_size = gr.Number(4096, label="Hidden Size")
116
+ num_layers = gr.Number(32, label="Number of Layers")
117
+ with gr.Row():
118
+ vocab_size = gr.Number(50432, label="Vocabulary Size")
119
+ intermediate_size = gr.Number(11008, label="Intermediate Size")
120
+ with gr.Row():
121
+ num_attention_heads = gr.Number(32, label="Number of Attention Heads")
122
+ num_key_value_heads = gr.Number(32, label="Number of Key Value Heads")
123
  tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")
 
 
124
 
125
  with gr.Accordion("Training Configuration", open=True):
126
+ with gr.Row():
127
+ seq_len = gr.Number(2048, label="Sequence Length")
128
+ mbs = gr.Number(1, label="Micro Batch Size")
129
+ batch_accum = gr.Number(1, label="Gradient Accumulation Steps")
130
 
131
  with gr.Accordion("Parallelism", open=True):
132
+ with gr.Row():
133
+ tp = gr.Number(1, label="Tensor Parallelism")
134
+ pp = gr.Number(1, label="Pipeline Parallelism")
135
+ dp = gr.Number(1, label="Data Parallelism")
136
  zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
137
 
138
  manual_submit = gr.Button("Calculate Memory (Manual Input)")
 
142
  plot1 = gr.Plot(label="Memory Component Breakdown")
143
  plot2 = gr.Plot(label="Aggregate Memory Metrics")
144
 
145
+ # Handle config text input
146
+ config_submit.click(
147
+ lambda x: process_yaml_and_update_ui(load_config_from_content(x) if x else None),
148
+ inputs=[config_text],
149
+ outputs=[
150
+ plot1, plot2, config_display, oom_display,
151
+ hidden_size, num_attention_heads, num_key_value_heads, num_layers,
152
+ vocab_size, intermediate_size, seq_len, mbs, batch_accum,
153
+ tp, pp, dp, zero_stage, tie_word_embeddings
154
+ ]
155
  )
156
 
157
+ def process_yaml_and_update_ui(config):
158
+ if not config:
159
+ return [None, None, "No configuration loaded", None] + [gr.update() for _ in range(14)]
160
+
161
+ fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
162
+ oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
163
+
164
+ # Return values for all outputs including UI updates
165
+ return [
166
+ fig1, fig2,
167
+ format_config_display(config),
168
+ oom_prediction,
169
+ # UI component updates
170
+ config['hidden_size'],
171
+ config['num_attention_heads'],
172
+ config['num_key_value_heads'],
173
+ config['num_layers'],
174
+ config['vocab_size'],
175
+ config['intermediate_size'],
176
+ config['seq_len'],
177
+ config['mbs'],
178
+ config['batch_accum'],
179
+ config['tp'],
180
+ config['pp'],
181
+ config['dp'],
182
+ config['zero_stage'],
183
+ config['tie_word_embeddings']
184
+ ]
185
 
186
  # Handle manual input
187
  def manual_input_to_config(*args):
 
201
  'num_attention_heads': args[1],
202
  'num_key_value_heads': args[2]
203
  }
204
+ return process_yaml_and_update_ui(config)
205
 
206
  manual_submit.click(
207
  manual_input_to_config,