nouamanetazi HF Staff committed on
Commit
5f67cc3
·
1 Parent(s): 421f3af
Files changed (3) hide show
  1. .gitignore +175 -0
  2. app.py +145 -0
  3. utils.py +152 -0
.gitignore ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mem_viz/
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ .vscode
164
+
165
+ checkpoints/
166
+ wandb/
167
+
168
+
169
+ gg/
170
+ lighteval/
171
+ logs/
172
+ snapshots/
173
+ tb_logs*
174
+ .test_cache/
175
+ benchmark/
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import yaml
4
+ from pathlib import Path
5
+ import io
6
+ from utils import calculate_memory_components, plot_memory_breakdown
7
+
8
+
9
def load_config_from_yaml_content(yaml_content):
    """Parse a training-config YAML string into the flat parameter dict used by the app.

    The YAML is expected to contain the sections ``model.model_config``,
    ``parallelism``, ``tokens`` and ``optimizer``; the values read from them
    are re-keyed to the keyword names accepted by ``plot_memory_breakdown``.

    Args:
        yaml_content: YAML document as a string.

    Returns:
        dict with the keys ``hidden_size``, ``num_layers``, ``vocab_size``,
        ``intermediate_size``, ``seq_len``, ``mbs``, ``batch_accum``, ``tp``,
        ``pp``, ``dp``, ``zero_stage`` and ``tie_word_embeddings``.

    Raises:
        gr.Error: if the text is not valid YAML, is not a mapping at the top
            level, or lacks one of the required keys. The message always
            starts with ``"Error parsing YAML: "``.
    """
    # safe_load only builds plain Python objects, so pasted/uploaded content
    # cannot trigger arbitrary object construction.
    try:
        config = yaml.safe_load(yaml_content)
    except yaml.YAMLError as e:
        raise gr.Error(f"Error parsing YAML: {str(e)}") from e

    # Empty or scalar documents load as None / str / int; reject them here so
    # the user sees a meaningful message instead of a subscripting TypeError.
    if not isinstance(config, dict):
        raise gr.Error(
            "Error parsing YAML: expected a mapping at the top level, "
            f"got {type(config).__name__}"
        )

    try:
        # Extract relevant parameters from config
        model_config = config['model']['model_config']
        parallelism = config['parallelism']
        tokens = config['tokens']
        optimizer = config['optimizer']

        return {
            'hidden_size': model_config['hidden_size'],
            'num_layers': model_config['num_hidden_layers'],
            'vocab_size': model_config['vocab_size'],
            'intermediate_size': model_config['intermediate_size'],
            'seq_len': tokens['sequence_length'],
            'mbs': tokens['micro_batch_size'],
            'batch_accum': tokens['batch_accumulation_per_replica'],
            'tp': parallelism['tp'],
            'pp': parallelism['pp'],
            'dp': parallelism['dp'],
            'zero_stage': optimizer['zero_stage'],
            'tie_word_embeddings': model_config['tie_word_embeddings'],
        }
    except KeyError as e:
        raise gr.Error(f"Error parsing YAML: missing required key {e}") from e
    except TypeError as e:
        # A section exists but is not a mapping (e.g. ``model: 3``).
        raise gr.Error(f"Error parsing YAML: {str(e)}") from e
35
+
36
def load_config_from_yaml_file(yaml_path):
    """Read an uploaded YAML file and parse it into the app's parameter dict.

    Args:
        yaml_path: value delivered by the file-upload component. Either a
            plain filesystem path, or an object exposing the path through a
            ``.name`` attribute; both forms are accepted.

    Returns:
        The dict produced by ``load_config_from_yaml_content``, or ``None``
        when no file is given (falsy ``yaml_path``).
    """
    if not yaml_path:
        return None
    # Fall back to the value itself when it has no ``.name`` (plain str path).
    path = getattr(yaml_path, "name", yaml_path)
    # Read as UTF-8 explicitly so parsing does not depend on the host locale.
    with open(path, 'r', encoding='utf-8') as f:
        return load_config_from_yaml_content(f.read())
41
+
42
def format_config_display(config):
    """Render the flat parameter dict as a Markdown summary grouped by topic.

    Args:
        config: parameter dict keyed by the app's short names
            (``hidden_size``, ``tp``, ...). A falsy value means "nothing loaded".

    Returns:
        A Markdown string with one ``###`` heading per group followed by a
        ``- name: value`` bullet per parameter, or the literal text
        ``"No configuration loaded"`` when ``config`` is falsy.

    Raises:
        KeyError: if ``config`` is non-empty but lacks a listed parameter.
    """
    if not config:
        return "No configuration loaded"

    # (heading, parameter names) in display order.
    layout = (
        ("Model Architecture",
         ("hidden_size", "num_layers", "vocab_size",
          "intermediate_size", "tie_word_embeddings")),
        ("Training Configuration",
         ("seq_len", "mbs", "batch_accum")),
        ("Parallelism",
         ("tp", "pp", "dp", "zero_stage")),
    )

    lines = []
    for heading, keys in layout:
        # The leading newline puts a blank line before every heading once joined.
        lines.append(f"\n### {heading}")
        lines.extend(f"- {key}: {config[key]}" for key in keys)

    return "\n".join(lines)
66
+
67
def process_yaml_and_plot(config):
    """Turn a parameter dict into the three values shown in the output column.

    Args:
        config: flat parameter dict whose keys match the keyword arguments of
            ``plot_memory_breakdown``, or a falsy value when nothing is loaded.

    Returns:
        ``(components_figure, aggregates_figure, markdown_summary)``. When
        ``config`` is falsy both figures are ``None`` and the summary is
        ``"No configuration loaded"``.
    """
    if config:
        components_fig, aggregates_fig = plot_memory_breakdown(**config)
        return components_fig, aggregates_fig, format_config_display(config)
    return None, None, "No configuration loaded"
72
+
73
# UI definition: inputs (YAML upload / paste / manual fields) on the left,
# the configuration summary and the two memory plots on the right.
# NOTE(review): nesting below is reconstructed from a flattened diff view —
# confirm it against the repository copy of app.py.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("YAML Configuration", open=True):
                yaml_file = gr.File(label="Upload YAML Config", file_types=[".yaml", ".yml"])
                yaml_text = gr.Textbox(
                    label="Or paste YAML content here",
                    placeholder="Paste your YAML configuration here...",
                    lines=10
                )
                yaml_submit = gr.Button("Calculate Memory from YAML")

            with gr.Accordion("Manual Configuration", open=False):
                # precision=0 makes each Number deliver an int rather than a
                # float, so the memory formulas and the configuration summary
                # receive 4096 rather than 4096.0.
                with gr.Accordion("Model Architecture", open=True):
                    hidden_size = gr.Number(4096, label="Hidden Size", precision=0)
                    num_layers = gr.Number(32, label="Number of Layers", precision=0)
                    vocab_size = gr.Number(50432, label="Vocabulary Size", precision=0)
                    intermediate_size = gr.Number(11008, label="Intermediate Size", precision=0)
                    tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")

                with gr.Accordion("Training Configuration", open=True):
                    seq_len = gr.Number(2048, label="Sequence Length", precision=0)
                    mbs = gr.Number(1, label="Micro Batch Size", precision=0)
                    batch_accum = gr.Number(1, label="Gradient Accumulation Steps", precision=0)

                with gr.Accordion("Parallelism", open=True):
                    tp = gr.Number(1, label="Tensor Parallelism", precision=0)
                    pp = gr.Number(1, label="Pipeline Parallelism", precision=0)
                    dp = gr.Number(1, label="Data Parallelism", precision=0)
                    zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")

                manual_submit = gr.Button("Calculate Memory (Manual Input)")

        with gr.Column(scale=2):
            config_display = gr.Markdown(label="Configuration Values")
            plot1 = gr.Plot(label="Memory Component Breakdown")
            plot2 = gr.Plot(label="Aggregate Memory Metrics")

    # Handle YAML file upload (also fires with an empty value when the file
    # is cleared, which resets the outputs).
    yaml_file.change(
        lambda x: process_yaml_and_plot(load_config_from_yaml_file(x) if x else None),
        inputs=[yaml_file],
        outputs=[plot1, plot2, config_display]
    )

    # Handle YAML text input
    yaml_submit.click(
        lambda x: process_yaml_and_plot(load_config_from_yaml_content(x) if x else None),
        inputs=[yaml_text],
        outputs=[plot1, plot2, config_display]
    )

    # Handle manual input
    def manual_input_to_config(*args):
        """Pair the positional widget values with their parameter names and plot.

        The name order here must match the ``inputs`` list passed to
        ``manual_submit.click`` below.
        """
        config = dict(zip([
            'hidden_size', 'num_layers', 'vocab_size', 'intermediate_size',
            'seq_len', 'mbs', 'batch_accum', 'tp', 'pp', 'dp', 'zero_stage',
            'tie_word_embeddings'
        ], args))
        return process_yaml_and_plot(config)

    manual_submit.click(
        manual_input_to_config,
        inputs=[
            hidden_size, num_layers, vocab_size, intermediate_size,
            seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
            tie_word_embeddings
        ],
        outputs=[plot1, plot2, config_display]
    )

if __name__ == "__main__":
    demo.launch()
utils.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import matplotlib.pyplot as plt
3
+
4
def calculate_memory_components(
    hidden_size, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings
):
    """Estimate training memory, in MiB, split into components and aggregates.

    Args:
        hidden_size, num_layers, vocab_size, intermediate_size: model dimensions.
        seq_len, mbs, batch_accum: sequence length, micro batch size and
            gradient accumulation steps.
        tp, pp, dp: tensor / pipeline / data parallel sizes (used as divisors,
            so ``tp`` and ``pp`` must be non-zero).
        zero_stage: 0 disables sharding of FP32 parameters and optimizer
            states; any other value divides them by ``dp``.
        tie_word_embeddings: when false and ``pp == 1`` the embedding matrix
            is counted twice.

    Returns:
        ``{"Components": {...}, "Aggregates": {...}}`` where every value is a
        size in MiB.
    """
    mib_per_bf16 = 2 / 1024 / 1024  # 2 bytes per element, expressed in MiB
    layers_per_stage = num_layers // pp

    # --- BF16 model weights -------------------------------------------------
    embedding_copies = 2 if (not tie_word_embeddings and pp == 1) else 1
    vocab_embeddings = vocab_size * hidden_size * embedding_copies

    qkv_proj = hidden_size * 3 * hidden_size
    out_proj = hidden_size * hidden_size
    gate_up_proj = hidden_size * 2 * intermediate_size
    down_proj = intermediate_size * hidden_size
    layer_params = qkv_proj + out_proj + gate_up_proj + down_proj

    model_bf16 = (vocab_embeddings + layers_per_stage * layer_params) * mib_per_bf16 / tp

    # --- Parameter copies, gradients, optimizer -----------------------------
    shard_factor = 1 if zero_stage == 0 else dp
    fp32_params_sharded = (2 * model_bf16) / shard_factor
    fp32_grads = 2 * model_bf16
    optimstates_sharded = (4 * model_bf16) / shard_factor
    # DDP gradient buffers are only counted without ZeRO and with dp > 1.
    ddp_grads_buffers = model_bf16 if (zero_stage == 0 and dp > 1) else 0
    overhead = 72 + 32 * mbs

    # --- Activations --------------------------------------------------------
    decoder_layer_mib = (
        (seq_len * mbs * hidden_size/tp)
        * mib_per_bf16
        * (4*intermediate_size/hidden_size + 10)
    )
    if pp > 1:
        activs = min(pp, batch_accum) * layers_per_stage * decoder_layer_mib
    else:
        # Logits cast to FP32 and the sharded cross-entropy buffer use the
        # same size estimate.
        logits_mib = seq_len * mbs * vocab_size * mib_per_bf16 * 2 / tp
        cast_to_fp32 = logits_mib
        sharded_cross_entropy = logits_mib
        activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy

    # --- Aggregates ---------------------------------------------------------
    before_optimstates = (
        model_bf16 + fp32_params_sharded + fp32_grads + ddp_grads_buffers
    )
    after_optimstates = (
        model_bf16 + fp32_params_sharded + fp32_grads
        + optimstates_sharded + ddp_grads_buffers + overhead
    )
    # Peak is everything resident after the optimizer step plus activations.
    peak_tbi = after_optimstates + activs

    components = {
        "Model BF16": model_bf16,
        "FP32 Parameters": fp32_params_sharded,
        "FP32 Gradients": fp32_grads,
        "Optimizer States": optimstates_sharded,
        "DDP Gradient Buffers": ddp_grads_buffers,
        "Overhead": overhead,
        "Activations": activs,
    }
    aggregates = {
        "Memory Before Optimizer States": before_optimstates,
        "Memory After Optimizer States": after_optimstates,
        "Peak Memory (TBI)": peak_tbi,
    }
    return {"Components": components, "Aggregates": aggregates}
85
+
86
def _draw_memory_bar_chart(fig, values_by_name, title, color=None):
    """Draw one labelled bar chart of MiB values onto ``fig`` and tighten its layout."""
    ax = fig.add_subplot(1, 1, 1)
    names = list(values_by_name.keys())
    values = list(values_by_name.values())
    positions = range(len(values_by_name))

    # Only forward ``color`` when given so the default colour cycle applies otherwise.
    bar_kwargs = {} if color is None else {"color": color}
    bars = ax.bar(positions, values, **bar_kwargs)

    # Print each bar's value just above its top edge.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.1f} MiB',
                ha='center', va='bottom')

    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Memory (MiB)')
    ax.set_title(title, pad=20)

    # Adjust layout on this specific figure (not pyplot's "current" one) so
    # the rotated tick labels are not clipped.
    fig.tight_layout()


def plot_memory_breakdown(
    hidden_size, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings
):
    """Compute the memory estimate and render it as two bar charts.

    All arguments are forwarded unchanged to ``calculate_memory_components``.

    Returns:
        ``(fig1, fig2)``: ``fig1`` plots the per-component values
        ("Memory Component Breakdown"), ``fig2`` plots the aggregate values
        ("Aggregate Memory Metrics") in orange.

    Side effects:
        Closes every figure currently registered with pyplot before creating
        the two new ones.
    """
    results = calculate_memory_components(
        hidden_size, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings
    )

    # Drop figures left over from previous calls; pyplot keeps them otherwise.
    plt.close('all')

    fig1 = plt.figure(figsize=(10, 6))
    _draw_memory_bar_chart(fig1, results["Components"], 'Memory Component Breakdown')

    fig2 = plt.figure(figsize=(10, 6))
    _draw_memory_bar_chart(
        fig2, results["Aggregates"], 'Aggregate Memory Metrics', color='orange'
    )

    return fig1, fig2