Spaces:
Sleeping
Sleeping
More info in title
Browse files- bar_plot.py +25 -16
- data.py +10 -8
bar_plot.py
CHANGED
|
@@ -75,24 +75,35 @@ def create_matplotlib_bar_plot() -> None:
|
|
| 75 |
fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
|
| 76 |
fig.patch.set_facecolor('#000000')
|
| 77 |
|
| 78 |
-
# Load and
|
| 79 |
per_device_data = load_data()
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# TTFT Plot (left)
|
| 86 |
ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
|
| 87 |
-
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "
|
| 88 |
|
| 89 |
# # ITL Plot (right)
|
| 90 |
itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
|
| 91 |
-
draw_bar_plot(axs[1], itl_bars, itl_errors,
|
| 92 |
-
|
| 93 |
-
#
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
| 96 |
plt.tight_layout()
|
| 97 |
|
| 98 |
# Add common legend with full text
|
|
@@ -102,7 +113,7 @@ def create_matplotlib_bar_plot() -> None:
|
|
| 102 |
|
| 103 |
# Put a legend to the right of the current axis
|
| 104 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
| 105 |
-
bbox_to_anchor=(0.515, -0.
|
| 106 |
labelcolor='white', fontsize=14)
|
| 107 |
|
| 108 |
# Save plot to bytes with high DPI for crisp text
|
|
@@ -124,7 +135,7 @@ def create_matplotlib_bar_plot() -> None:
|
|
| 124 |
return html
|
| 125 |
|
| 126 |
|
| 127 |
-
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list,
|
| 128 |
ax.set_facecolor('#000000')
|
| 129 |
ax.grid(True, alpha=0.2, color='white', zorder=0)
|
| 130 |
# Draw bars
|
|
@@ -134,10 +145,8 @@ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylab
|
|
| 134 |
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 135 |
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
|
| 136 |
)
|
| 137 |
-
# Set labels and
|
| 138 |
ax.set_ylabel(ylabel, color='white', fontsize=16)
|
| 139 |
-
ax.set_title(title, color='white', fontsize=18, pad=20)
|
| 140 |
-
# Set ticks and grid
|
| 141 |
ax.set_xticks([])
|
| 142 |
ax.tick_params(colors='white', labelsize=13)
|
| 143 |
ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
|
|
|
|
| 75 |
fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
|
| 76 |
fig.patch.set_facecolor('#000000')
|
| 77 |
|
| 78 |
+
# Load data and ensure coherence
|
| 79 |
per_device_data = load_data()
|
| 80 |
+
batch_size, sequence_length, num_tokens_to_generate = None, None, None
|
| 81 |
+
for device_name, device_data in per_device_data.items():
|
| 82 |
+
bs, seqlen, n_tok = device_data.ensure_coherence()
|
| 83 |
+
if batch_size is None:
|
| 84 |
+
batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
|
| 85 |
+
elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
|
| 86 |
+
fig.suptitle(
|
| 87 |
+
f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
|
| 88 |
+
f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
|
| 89 |
+
color='white', fontsize=18, pad=20
|
| 90 |
+
)
|
| 91 |
+
return None
|
| 92 |
|
| 93 |
# TTFT Plot (left)
|
| 94 |
ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
|
| 95 |
+
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
|
| 96 |
|
| 97 |
# # ITL Plot (right)
|
| 98 |
itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
|
| 99 |
+
draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)
|
| 100 |
+
|
| 101 |
+
# Title and tight layout
|
| 102 |
+
title = "\n".join([
|
| 103 |
+
"Time to first token and inter-token latency (lower is better)",
|
| 104 |
+
f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
|
| 105 |
+
])
|
| 106 |
+
fig.suptitle(title, color='white', fontsize=20, y=1.005)
|
| 107 |
plt.tight_layout()
|
| 108 |
|
| 109 |
# Add common legend with full text
|
|
|
|
| 113 |
|
| 114 |
# Put a legend to the right of the current axis
|
| 115 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
| 116 |
+
bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
|
| 117 |
labelcolor='white', fontsize=14)
|
| 118 |
|
| 119 |
# Save plot to bytes with high DPI for crisp text
|
|
|
|
| 135 |
return html
|
| 136 |
|
| 137 |
|
| 138 |
+
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, ylabel: str, xticks: list[tuple[float, str]]):
|
| 139 |
ax.set_facecolor('#000000')
|
| 140 |
ax.grid(True, alpha=0.2, color='white', zorder=0)
|
| 141 |
# Draw bars
|
|
|
|
| 145 |
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 146 |
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
|
| 147 |
)
|
| 148 |
+
# Set labels, ticks and grid
|
| 149 |
ax.set_ylabel(ylabel, color='white', fontsize=16)
|
|
|
|
|
|
|
| 150 |
ax.set_xticks([])
|
| 151 |
ax.tick_params(colors='white', labelsize=13)
|
| 152 |
ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
|
data.py
CHANGED
|
@@ -25,15 +25,17 @@ class ModelBenchmarkData:
|
|
| 25 |
num_tokens = len(measures["t_tokens"]) - 1
|
| 26 |
return delta_t / num_tokens
|
| 27 |
|
| 28 |
-
def
|
| 29 |
-
|
| 30 |
for cfg_name, data in self.data.items():
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
|
| 39 |
# Gather data for each scenario
|
|
|
|
| 25 |
num_tokens = len(measures["t_tokens"]) - 1
|
| 26 |
return delta_t / num_tokens
|
| 27 |
|
| 28 |
+
def ensure_coherence(self) -> tuple[int, int, int]:
|
| 29 |
+
all_hyperparams = set()
|
| 30 |
for cfg_name, data in self.data.items():
|
| 31 |
+
config = data["metadata"]["config"]
|
| 32 |
+
hyperparams = (config["batch_size"], config["sequence_length"], config["num_tokens_to_generate"])
|
| 33 |
+
all_hyperparams.add(hyperparams)
|
| 34 |
+
if len(all_hyperparams) > 1:
|
| 35 |
+
raise ValueError(
|
| 36 |
+
f"Different batch size, sequence length or nb of tokens to generate between configs: {all_hyperparams}"
|
| 37 |
+
)
|
| 38 |
+
return all_hyperparams.pop()
|
| 39 |
|
| 40 |
def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
|
| 41 |
# Gather data for each scenario
|