ror HF Staff commited on
Commit
22cf82d
·
1 Parent(s): 9f6e83d

More info in title

Browse files
Files changed (2) hide show
  1. bar_plot.py +25 -16
  2. data.py +10 -8
bar_plot.py CHANGED
@@ -75,24 +75,35 @@ def create_matplotlib_bar_plot() -> None:
75
  fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
76
  fig.patch.set_facecolor('#000000')
77
 
78
- # Load and sanitize data
79
  per_device_data = load_data()
80
- batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
81
- if len(set(batch_sizes.values())) > 1:
82
- fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18, pad=20)
83
- return None
 
 
 
 
 
 
 
 
84
 
85
  # TTFT Plot (left)
86
  ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
87
- draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
88
 
89
  # # ITL Plot (right)
90
  itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
91
- draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
92
-
93
- # # E2E Plot (right)
94
- # e2e_bars, e2e_errors = make_bar_kwargs("e2e")
95
- # draw_bar_plot(axs, e2e_bars, e2e_errors, "End-to-end latency (lower is better)", "E2E (seconds)")
 
 
 
96
  plt.tight_layout()
97
 
98
  # Add common legend with full text
@@ -102,7 +113,7 @@ def create_matplotlib_bar_plot() -> None:
102
 
103
  # Put a legend to the right of the current axis
104
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
105
- bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
106
  labelcolor='white', fontsize=14)
107
 
108
  # Save plot to bytes with high DPI for crisp text
@@ -124,7 +135,7 @@ def create_matplotlib_bar_plot() -> None:
124
  return html
125
 
126
 
127
- def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str, xticks: list[tuple[float, str]]):
128
  ax.set_facecolor('#000000')
129
  ax.grid(True, alpha=0.2, color='white', zorder=0)
130
  # Draw bars
@@ -134,10 +145,8 @@ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylab
134
  bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
135
  fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
136
  )
137
- # Set labels and title
138
  ax.set_ylabel(ylabel, color='white', fontsize=16)
139
- ax.set_title(title, color='white', fontsize=18, pad=20)
140
- # Set ticks and grid
141
  ax.set_xticks([])
142
  ax.tick_params(colors='white', labelsize=13)
143
  ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
 
75
  fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
76
  fig.patch.set_facecolor('#000000')
77
 
78
+ # Load data and ensure coherence
79
  per_device_data = load_data()
80
+ batch_size, sequence_length, num_tokens_to_generate = None, None, None
81
+ for device_name, device_data in per_device_data.items():
82
+ bs, seqlen, n_tok = device_data.ensure_coherence()
83
+ if batch_size is None:
84
+ batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
85
+ elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
86
+ fig.suptitle(
87
+ f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
88
+ f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
89
+ color='white', fontsize=18, pad=20
90
+ )
91
+ return None
92
 
93
  # TTFT Plot (left)
94
  ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
95
+ draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
96
 
97
  # # ITL Plot (right)
98
  itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
99
+ draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)
100
+
101
+ # Title and tight layout
102
+ title = "\n".join([
103
+ "Time to first token and inter-token latency (lower is better)",
104
+ f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
105
+ ])
106
+ fig.suptitle(title, color='white', fontsize=20, y=1.005)
107
  plt.tight_layout()
108
 
109
  # Add common legend with full text
 
113
 
114
  # Put a legend to the right of the current axis
115
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
116
+ bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
117
  labelcolor='white', fontsize=14)
118
 
119
  # Save plot to bytes with high DPI for crisp text
 
135
  return html
136
 
137
 
138
+ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, ylabel: str, xticks: list[tuple[float, str]]):
139
  ax.set_facecolor('#000000')
140
  ax.grid(True, alpha=0.2, color='white', zorder=0)
141
  # Draw bars
 
145
  bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
146
  fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
147
  )
148
+ # Set labels, ticks and grid
149
  ax.set_ylabel(ylabel, color='white', fontsize=16)
 
 
150
  ax.set_xticks([])
151
  ax.tick_params(colors='white', labelsize=13)
152
  ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
data.py CHANGED
@@ -25,15 +25,17 @@ class ModelBenchmarkData:
25
  num_tokens = len(measures["t_tokens"]) - 1
26
  return delta_t / num_tokens
27
 
28
- def get_main_batch_size(self) -> int:
29
- batch_sizes = {}
30
  for cfg_name, data in self.data.items():
31
- for measure in data["measures"]:
32
- bs = measure["batch_size"]
33
- if bs not in batch_sizes:
34
- batch_sizes[bs] = 0
35
- batch_sizes[bs] += 1
36
- return max(batch_sizes, key=batch_sizes.get)
 
 
37
 
38
  def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
39
  # Gather data for each scenario
 
25
  num_tokens = len(measures["t_tokens"]) - 1
26
  return delta_t / num_tokens
27
 
28
+ def ensure_coherence(self) -> tuple[int, int, int]:
29
+ all_hyperparams = set()
30
  for cfg_name, data in self.data.items():
31
+ config = data["metadata"]["config"]
32
+ hyperparams = (config["batch_size"], config["sequence_length"], config["num_tokens_to_generate"])
33
+ all_hyperparams.add(hyperparams)
34
+ if len(all_hyperparams) > 1:
35
+ raise ValueError(
36
+ f"Different batch size, sequence length or nb of tokens to generate between configs: {all_hyperparams}"
37
+ )
38
+ return all_hyperparams.pop()
39
 
40
  def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
41
  # Gather data for each scenario