Spaces:
Running
Running
File size: 6,824 Bytes
55c8a69 dc41c89 55c8a69 e1f4b73 55c8a69 165f130 9f6e83d 165f130 dc41c89 55c8a69 dc41c89 165f130 dc41c89 165f130 dc41c89 55c8a69 e1f4b73 55c8a69 dc41c89 55c8a69 22cf82d 165f130 22cf82d 165f130 55c8a69 165f130 22cf82d 55c8a69 dc41c89 165f130 22cf82d dc41c89 55c8a69 165f130 dc41c89 e1f4b73 22cf82d e1f4b73 55c8a69 e1f4b73 55c8a69 e1f4b73 55c8a69 dc41c89 22cf82d dc41c89 165f130 dc41c89 22cf82d dc41c89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data
def reorder_data(per_scenario_data: dict) -> dict:
    """Return the scenarios of ``per_scenario_data`` in a stable display order.

    Scenarios are ordered by:
    1. attention implementation: flash attention 2, then SDPA, then eager, then anything else;
    2. SDPA backend name, where a missing or ``None`` backend sorts first;
    3. the ``kernelize`` flag;
    4. the ``compilation`` flag.

    The sort is stable, so scenarios with identical keys keep their input order.

    Args:
        per_scenario_data: Mapping of scenario name to scenario data. Each value
            must hold a ``"config"`` dict with the keys ``attn_implementation``,
            ``kernelize`` and ``compilation``; ``sdpa_backend`` is optional.

    Returns:
        A new dict with the same (key, value) pairs, inserted in sorted order.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        # Unrecognized implementations sort last instead of raising KeyError.
        prio = attn_priority.get(cfg["attn_implementation"], len(attn_priority))
        # A config without a backend would otherwise contribute None. Comparing
        # None with a backend name (e.g. None < "math") raises TypeError, so
        # map it to "".
        backend = cfg.get("sdpa_backend") or ""
        return prio, backend, cfg["kernelize"], cfg["compilation"]

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sorting_fn)}
def infer_bar_label(config: dict) -> str:
    """Format a benchmark configuration as a readable legend label.

    The label has the form ``"<attention>, <compile state>, <kernel state>"``,
    e.g. ``"SDPA (math), compiled, no kernels"``.

    Args:
        config: Scenario config with the keys ``attn_implementation``,
            ``compilation`` and ``kernelize``. ``sdpa_backend`` is read only
            when the implementation is ``"sdpa"``.

    Returns:
        The formatted label. Unrecognized implementations are labelled
        ``"Unknown"``; unrecognized SDPA backends ``"SDPA (unknown backend)"``.
    """
    attn_name = config["attn_implementation"]
    if attn_name == "sdpa":
        sdpa_labels = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }
        attn_label = sdpa_labels.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        other_labels = {"eager": "Eager", "flash_attention_2": "Flash attention"}
        attn_label = other_labels.get(attn_name, "Unknown")
    # Named compile_state rather than `compile` to avoid shadowing the builtin.
    compile_state = "compiled" if config["compilation"] else "no compile"
    kernel_state = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_state}, {kernel_state}"
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list[tuple[float, str]]]:
    """Flatten per-device benchmark results into arguments for ``Axes.bar``.

    Bars are placed left to right, one per scenario and one unit apart, grouped
    by device. Consecutive device groups are separated by an extra 1.5 units.
    Within a device, scenarios are ordered with ``reorder_data``.

    Args:
        per_device_data: Mapping of device name to a per-device data object.
            Its ``get_bar_plot_data()`` must return a mapping of scenario name
            to scenario data, where each entry holds a ``"config"`` dict and the
            samples for ``key``. (Structure inferred from usage here; confirm
            against the data module.)
        key: Name of the metric to plot, e.g. ``"ttft"`` or ``"itl"``.

    Returns:
        A 3-tuple ``(bar_kwargs, error_bars, x_ticks)``:
        - ``bar_kwargs``: dict of lists under ``x``, ``height`` (median of the
          samples), ``color`` and ``label``, ready for ``ax.bar(**bar_kwargs)``;
        - ``error_bars``: standard deviation of the samples, one per bar;
        - ``x_ticks``: ``(position, device_name)`` pairs. There is one per device
          with at least one scenario, centered under that device's bars.
    """
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    error_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        device_xs = []
        for scenario_data in per_scenario_data.values():
            samples = scenario_data[key]
            config = scenario_data["config"]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            error_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1
        # np.mean([]) warns and returns NaN, which would place a tick at NaN.
        # Devices without scenarios therefore get no tick; their gap is still kept.
        if device_xs:
            x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5
    return bar_kwargs, error_bars, x_ticks
def _figure_to_html(fig) -> str:
    """Render ``fig`` as a base64 PNG inside a full-viewport HTML ``<div>``, then close it."""
    buffer = io.BytesIO()
    try:
        # A high DPI keeps text crisp when the image is scaled to the viewport.
        fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even if rendering fails, so repeated calls don't leak.
        plt.close(fig)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    return f"""
<div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
</div>
"""


def _draw_latency_plots(fig, axs) -> None:
    """Load benchmark data and draw the TTFT/ITL panels, title and legend on ``fig``.

    ``axs`` holds the two stacked axes: ``axs[0]`` (top) gets TTFT and
    ``axs[1]`` (bottom) gets ITL.

    Devices can disagree on batch size, sequence length or number of tokens to
    generate; their results are then not comparable. In that case the axes are
    hidden and only a title describing the first mismatch is drawn.
    """
    per_device_data = load_data()

    # (batch_size, sequence_length, num_tokens_to_generate) of the first device.
    reference = None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if reference is None:
            reference = (bs, seqlen, n_tok)
        elif (bs, seqlen, n_tok) != reference:
            batch_size, sequence_length, num_tokens_to_generate = reference
            for ax in axs:
                ax.set_visible(False)
            # Figure.suptitle has no `pad` argument; passing one raised an error.
            fig.suptitle(
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                color='white', fontsize=18,
            )
            return
    batch_size, sequence_length, num_tokens_to_generate = reference or (None, None, None)

    # TTFT plot (top panel)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
    # ITL plot (bottom panel)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

    title = "\n".join([
        "Time to first token and inter-token latency (lower is better)",
        f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
    ])
    fig.suptitle(title, color='white', fontsize=20, y=1.005)
    fig.tight_layout()

    # One legend entry per distinct label, keeping the first color seen for it.
    # The same configuration repeats once per device. Deduplicating works for
    # any number of devices, whereas taking the first half only works for
    # exactly two devices with identical scenarios.
    legend_entries = {}
    for label, color in zip(ttft_bars["label"], ttft_bars["color"]):
        legend_entries.setdefault(label, color)
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_entries.values()]
    # Common legend centered below both panels.
    fig.legend(legend_handles, list(legend_entries), loc='lower center', ncol=4,
               bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
               labelcolor='white', fontsize=14)


def create_matplotlib_bar_plot() -> str:
    """Render TTFT and ITL bar charts for every device as an embeddable HTML snippet.

    The figure has two vertically stacked panels sharing the x axis: time to
    first token (top) and inter-token latency (bottom). Each bar is the median
    of a scenario's samples, and its error bar is their standard deviation.
    If devices disagree on generation parameters, the image contains only a
    title describing the mismatch.

    Note: ``plt.style.use('dark_background')`` changes matplotlib's global
    style, and the change persists after this call.

    Returns:
        An HTML ``<div>`` embedding the figure as a base64-encoded PNG.
    """
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
    fig.patch.set_facecolor('#000000')
    try:
        _draw_latency_plots(fig, axs)
    except Exception:
        # Don't leak the open figure if loading or drawing fails.
        plt.close(fig)
        raise
    return _figure_to_html(fig)
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, ylabel: str, xticks: list[tuple[float, str]]) -> None:
    """Draw one bar chart with error bars on ``ax`` using the dark theme.

    The y axis is then narrowed to the bars' error-bar range padded by 2%, so
    small differences stay visible. The limits matplotlib chose are never widened.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments for ``Axes.bar``, with list values under
            ``x``, ``height``, ``color`` and ``label``.
        errors: Symmetric error-bar half-lengths, one per bar.
        ylabel: Label for the y axis.
        xticks: ``(position, label)`` pairs for the x-axis ticks.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.2, color='white', zorder=0)

    # zorder 3/4 keeps the bars and error bars above the grid lines (zorder 0).
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )

    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([pos for pos, _ in xticks], [label for _, label in xticks], fontsize=16)

    # Vertical extent (bottom, top) of each bar including its error bar.
    # NaN entries (e.g. a scenario without samples) are ignored so they cannot
    # poison the limits.
    extents = [(h - e, h + e) for h, e in zip(bar_kwargs["height"], errors)]
    extents = [(lo, hi) for lo, hi in extents if np.isfinite(lo) and np.isfinite(hi)]
    if not extents:
        # Nothing to fit. The old 1e9/-1e9 sentinels produced an inverted
        # ylim here, so keep matplotlib's automatic limits instead.
        return
    new_ymin = 0.98 * min(lo for lo, _ in extents)
    new_ymax = 1.02 * max(hi for _, hi in extents)
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
|