performative_dashboard / bar_plot.py
ror's picture
ror HF Staff
Small fixes and data update
c4f3c79
raw
history blame
6.97 kB
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with its entries in display order.

    Entries are ordered by attention implementation (``flash_attention_2``,
    then ``sdpa``, then ``eager``, then any unrecognized value), and ties are
    broken by SDPA backend, ``kernelize`` and ``compilation``, in that order.

    Args:
        per_scenario_data: Mapping from scenario name to a dict holding at
            least a ``"config"`` dict with the keys ``"attn_implementation"``,
            ``"sdpa_backend"``, ``"kernelize"`` and ``"compilation"``.

    Returns:
        A new dict containing the same items in sorted order. The input
        mapping itself is not modified.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        # Unrecognized implementations sort after the known ones instead of
        # raising KeyError.
        priority = attn_priority.get(cfg["attn_implementation"], len(attn_priority))
        backend = cfg["sdpa_backend"]
        # None cannot be ordered against a backend name, so "no backend" is
        # ranked first explicitly rather than compared directly.
        backend_rank = backend is not None
        backend_name = "" if backend is None else backend
        return priority, backend_rank, backend_name, cfg["kernelize"], cfg["compilation"]

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sorting_fn)}
def infer_bar_label(config: dict) -> str:
    """Build the human-readable legend label for one benchmark configuration.

    The label is made of three comma-separated parts: the attention
    implementation (including the SDPA backend when the implementation is
    ``"sdpa"``), whether compilation is enabled, and whether kernels are used.

    Args:
        config: Dict with the keys ``"attn_implementation"``,
            ``"compilation"`` and ``"kernelize"``; ``"sdpa_backend"`` is only
            read when the attention implementation is ``"sdpa"``.

    Returns:
        The formatted label, e.g. ``"Eager, no compile, no kernels"``.
    """
    attn_name = config["attn_implementation"]
    if attn_name == "sdpa":
        # SDPA labels additionally spell out which backend was selected.
        sdpa_labels = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }
        attn_label = sdpa_labels.get(config["sdpa_backend"], "SDPA (unknown backend)")
    elif attn_name == "flash_attention_2":
        attn_label = "Flash attention"
    elif attn_name == "eager":
        attn_label = "Eager"
    else:
        attn_label = "Unknown"

    compile_label = "compiled" if config["compilation"] else "no compile"
    kernel_label = "kernelized" if config["kernelize"] else "no kernels"
    return ", ".join((attn_label, compile_label, kernel_label))
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Assemble the bar positions, heights, colors and labels for one metric.

    Bars belonging to the same device are placed at consecutive x positions,
    and an extra gap of 1.5 separates one device's group from the next.

    Args:
        per_device_data: Mapping from device name to an object exposing
            ``get_bar_plot_data()``, which returns a mapping from scenario
            name to a dict holding a ``"config"`` dict and the measurements
            stored under ``key``.
        key: Name of the measurement series to plot.

    Returns:
        A tuple ``(bar_kwargs, errors_bars, x_ticks)`` where ``bar_kwargs``
        has the list-valued keys ``"x"``, ``"height"``, ``"color"`` and
        ``"label"`` (heights are the median of each series), ``errors_bars``
        holds the standard deviation of each series, and ``x_ticks`` is a list
        of ``(position, device_name)`` pairs centred on each device's group.
    """
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    errors_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # A device without scenarios has no bars to centre a tick on:
            # np.mean([]) would be NaN, and the group would only add a gap.
            continue
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            measurements = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(measurements))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            errors_bars.append(np.std(measurements))
            device_xs.append(current_x)
            current_x += 1
        # One tick per device, centred under that device's group of bars
        x_ticks.append((np.mean(device_xs), device_name))
        # Visual gap between two devices
        current_x += 1.5
    return bar_kwargs, errors_bars, x_ticks
def create_matplotlib_bar_plot() -> "str | None":
    """Render the TTFT and ITL bar charts and return them as an HTML snippet.

    The figure has two vertically stacked axes sharing their x axis: time to
    first token on top and inter-token latency below. It is rendered to a PNG
    and embedded as a base64 ``<img>`` inside a ``<div>``.

    Returns:
        The HTML string, or ``None`` when the devices returned by
        ``load_data()`` do not all report the same batch size, sequence length
        and number of generated tokens.
    """
    # Create figure with dark theme - maximum size for full screen
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)  # used to be 30, 16
    try:
        fig.patch.set_facecolor('#000000')

        # Load data and ensure coherence
        per_device_data = load_data()
        batch_size, sequence_length, num_tokens_to_generate = None, None, None
        for device_data in per_device_data.values():
            bs, seqlen, n_tok = device_data.ensure_coherence()
            if batch_size is None:
                batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
            elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
                # Devices were benchmarked with different settings, so the
                # bars would not be comparable. Nothing is rendered in that
                # case, so no title is set on the figure.
                return None

        # TTFT plot (top)
        ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)

        # ITL plot (bottom)
        itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

        # Title and tight layout
        title = "\n".join([
            "Time to first token and inter-token latency (lower is better)",
            f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
        ])
        fig.suptitle(title, color='white', fontsize=20, y=1.005, linespacing=1.5)
        fig.tight_layout()

        # Common legend: one entry per distinct label, in order of first
        # appearance, so it stays correct for any number of devices.
        legend_entries = {}
        for label, color in zip(ttft_bars["label"], ttft_bars["color"]):
            legend_entries.setdefault(label, color)
        legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_entries.values()]
        # Put the legend below the axes, centred horizontally
        fig.legend(legend_handles, list(legend_entries), loc='lower center', ncol=4,
                   bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=14)

        # Save plot to bytes with high DPI for crisp text. Saving through
        # `fig` avoids depending on which figure pyplot considers current.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=150)

        # Convert to base64 for HTML embedding
        img_data = base64.b64encode(buffer.getvalue()).decode()
    finally:
        # Always release the figure, including on early return or error
        plt.close(fig)

    # Return HTML with embedded image - full page coverage
    html = f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """
    return html
def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw one bar chart with error bars on ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``; its ``"x"`` and
            ``"height"`` entries are also used to position the error bars.
        errors: Error-bar half-lengths, one per bar.
        ylabel: Label of the y axis.
        xticks: ``(position, label)`` pairs for the x-axis ticks.
        adapt_ylim: When true, shrink the y limits to the range covered by the
            bars and their error bars, with a 2% margin. The limits are only
            ever tightened, never widened.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.3, color='white', axis='y', zorder=0)

    # Draw bars
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)

    # Add error bars on top of the bars
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )

    # Set labels and ticks. tick_params runs first so that the explicit
    # fontsize of the x tick labels is applied after the generic label size.
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)

    # Truncate axis to better fit the bars
    if adapt_ylim:
        bounds = [(h - e, h + e) for h, e in zip(bar_kwargs["height"], errors)]
        # Without any bar there is no range to fit, so keep the limits as is
        if bounds:
            new_ymin = min(0.98 * low for low, _ in bounds)
            new_ymax = max(1.02 * high for _, high in bounds)
            ymin, ymax = ax.get_ylim()
            ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))