nouamanetazi HF Staff committed on
Commit
4921bbf
·
1 Parent(s): c68510e
Files changed (1) hide show
  1. utils.py +19 -16
utils.py CHANGED
@@ -1,7 +1,12 @@
1
  import matplotlib.pyplot as plt
2
  import numpy as np
 
3
 
 
4
  def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size):
 
 
 
5
  # Get list of pipeline blocks and their costs
6
  pipeline_blocks = []
7
  block_costs = []
@@ -40,9 +45,9 @@ def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediat
40
  break
41
 
42
  num_hidden_layers_in_pp = blocks_in_rank0 - 1 # We exclude first rank as it's the embedding layer
43
- print("num_hidden_layers_in_pp", num_hidden_layers_in_pp)
44
  return num_hidden_layers_in_pp
45
 
 
46
  def calculate_memory_components(
47
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
48
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
@@ -77,9 +82,9 @@ def calculate_memory_components(
77
  overhead = 72 + 32 * mbs
78
 
79
  # Activations
80
- # decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 10)
81
  is_mha = num_key_value_heads == num_attention_heads
82
- decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
 
83
 
84
  if pp > 1:
85
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
@@ -144,7 +149,7 @@ def plot_memory_breakdown(
144
 
145
  # Create figure for components plot
146
  plt.close('all')
147
- fig1 = plt.figure(figsize=(10, 6))
148
  ax1 = fig1.add_subplot(1, 1, 1)
149
 
150
  # Plot components
@@ -152,7 +157,10 @@ def plot_memory_breakdown(
152
  names = list(components.keys())
153
  values = list(components.values())
154
 
155
- bars1 = ax1.bar(range(len(components)), values)
 
 
 
156
 
157
  # Add value labels with better positioning
158
  for bar in bars1:
@@ -171,7 +179,7 @@ def plot_memory_breakdown(
171
  plt.tight_layout()
172
 
173
  # Create figure for timeline plot
174
- fig2 = plt.figure(figsize=(12, 6))
175
  ax2 = fig2.add_subplot(1, 1, 1)
176
 
177
  # Define timeline steps and their components
@@ -194,12 +202,6 @@ def plot_memory_breakdown(
194
  ("FP32 Gradients", c["FP32 Gradients"]),
195
  ("Activations", c["Activations"])
196
  ],
197
- "After Fwd-Bwd": [
198
- ("Model BF16", c["Model BF16"]),
199
- ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
200
- ("FP32 Parameters", c["FP32 Parameters"]),
201
- ("FP32 Gradients", c["FP32 Gradients"])
202
- ],
203
  "Optimizer Step": [
204
  ("Model BF16", c["Model BF16"]),
205
  ("FP32 Parameters", c["FP32 Parameters"]),
@@ -225,8 +227,7 @@ def plot_memory_breakdown(
225
  # Plot timeline
226
  x = range(len(timeline_steps))
227
  bottom = np.zeros(len(timeline_steps))
228
- colors = plt.cm.Set3(np.linspace(0, 1, len(c)))
229
- color_map = dict(zip(c.keys(), colors))
230
 
231
  for component in c.keys():
232
  heights = []
@@ -245,7 +246,7 @@ def plot_memory_breakdown(
245
  ax2.set_xticklabels(timeline_steps.keys(), rotation=45, ha='right')
246
  ax2.set_ylabel('Memory (MiB)')
247
  ax2.set_title('Memory Timeline', pad=20)
248
- ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
249
 
250
  # Add total memory labels on top of each bar
251
  for i, total in enumerate(bottom):
@@ -253,9 +254,11 @@ def plot_memory_breakdown(
253
 
254
  # Adjust layout
255
  plt.tight_layout()
256
-
257
  # Set y-axis limit
258
  max_y_value = max(bottom)
259
  ax2.set_ylim(0, max(80000, max_y_value))
260
 
 
 
 
261
  return fig1, fig2
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
+ import functools
4
 
5
+ @functools.lru_cache(maxsize=None)
6
  def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size):
7
+ if pp_size == 1:
8
+ return num_layers
9
+
10
  # Get list of pipeline blocks and their costs
11
  pipeline_blocks = []
12
  block_costs = []
 
45
  break
46
 
47
  num_hidden_layers_in_pp = blocks_in_rank0 - 1 # We exclude first rank as it's the embedding layer
 
48
  return num_hidden_layers_in_pp
49
 
50
+ @functools.lru_cache(maxsize=None)
51
  def calculate_memory_components(
52
  hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
53
  seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
 
82
  overhead = 72 + 32 * mbs
83
 
84
  # Activations
 
85
  is_mha = num_key_value_heads == num_attention_heads
86
+ decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2)
87
+ # decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
88
 
89
  if pp > 1:
90
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
 
149
 
150
  # Create figure for components plot
151
  plt.close('all')
152
+ fig1 = plt.figure(figsize=(10, 5))
153
  ax1 = fig1.add_subplot(1, 1, 1)
154
 
155
  # Plot components
 
157
  names = list(components.keys())
158
  values = list(components.values())
159
 
160
+ colors = plt.cm.Set3(np.linspace(0, 1, len(components)))
161
+ color_map = dict(zip(names, colors))
162
+
163
+ bars1 = ax1.bar(range(len(components)), values, color=colors)
164
 
165
  # Add value labels with better positioning
166
  for bar in bars1:
 
179
  plt.tight_layout()
180
 
181
  # Create figure for timeline plot
182
+ fig2 = plt.figure(figsize=(10, 6))
183
  ax2 = fig2.add_subplot(1, 1, 1)
184
 
185
  # Define timeline steps and their components
 
202
  ("FP32 Gradients", c["FP32 Gradients"]),
203
  ("Activations", c["Activations"])
204
  ],
 
 
 
 
 
 
205
  "Optimizer Step": [
206
  ("Model BF16", c["Model BF16"]),
207
  ("FP32 Parameters", c["FP32 Parameters"]),
 
227
  # Plot timeline
228
  x = range(len(timeline_steps))
229
  bottom = np.zeros(len(timeline_steps))
230
+
 
231
 
232
  for component in c.keys():
233
  heights = []
 
246
  ax2.set_xticklabels(timeline_steps.keys(), rotation=45, ha='right')
247
  ax2.set_ylabel('Memory (MiB)')
248
  ax2.set_title('Memory Timeline', pad=20)
249
+
250
 
251
  # Add total memory labels on top of each bar
252
  for i, total in enumerate(bottom):
 
254
 
255
  # Adjust layout
256
  plt.tight_layout()
 
257
  # Set y-axis limit
258
  max_y_value = max(bottom)
259
  ax2.set_ylim(0, max(80000, max_y_value))
260
 
261
+ # Add legend below the plot
262
+ # plt.subplots_adjust(bottom=0.8)
263
+ ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.5), ncol=3)
264
  return fig1, fig2