# performative_dashboard / bar_plot.py
# Author: ror (HF Staff) — "Probably v1", commit 79e7993
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data, ModelBenchmarkData
def reorder_data(per_scenario_data: dict) -> dict:
    """Return *per_scenario_data* reordered into a stable, canonical order.

    Scenarios are sorted primarily by attention implementation
    (flash_attention_2 < sdpa < eager < flex_attention), then by SDPA
    backend, then by whether kernels are used, and finally by whether the
    run was compiled.  Unknown implementations/backends sort after the
    known ones instead of raising KeyError.

    Args:
        per_scenario_data: mapping of scenario name to scenario data; each
            value must carry a "config" dict with the keys read below.

    Returns:
        A new dict with the same items in canonical order (relies on
        insertion-ordered dicts, Python 3.7+).
    """
    # Hoisted out of the key function so they are built once, not per key.
    attn_impl_prio = {
        "flash_attention_2": 0,
        "sdpa": 1,
        "eager": 2,
        "flex_attention": 3,
    }
    sdpa_backend_prio = {
        None: -1,
        "flash_attention": 0,
        "math": 1,
        "efficient_attention": 2,
        "cudnn_attention": 3,
    }

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        return (
            # .get with a past-the-end default: unknown configs sort last
            # rather than crashing the whole dashboard with a KeyError.
            attn_impl_prio.get(cfg["attn_implementation"], len(attn_impl_prio)),
            sdpa_backend_prio.get(cfg["sdpa_backend"], len(sdpa_backend_prio)),
            cfg["kernelize"],
            cfg["compile_mode"] is not None,
        )

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sorting_fn)}
def infer_bar_label(config: dict) -> str:
    """Format legend labels to be more readable.

    Builds "<attention impl>, <compiled?>, <kernelized?>" from the scenario
    config; unknown implementations/backends fall back to "Unknown" /
    "SDPA (unknown backend)" instead of raising.
    """
    impl = config["attn_implementation"]
    if impl == "sdpa":
        # The SDPA label also encodes which backend was forced for SDPA.
        attn_implementation = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        attn_implementation = {
            "eager": "Eager",
            "flash_attention_2": "Flash attention",
            "flex_attention": "Flex attention",
        }.get(impl, "Unknown")
    # Renamed from `compile`, which shadowed the builtin of the same name.
    compile_label = "compiled" if config["compile_mode"] is not None else "no compile"
    kernels = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_implementation}, {compile_label}, {kernels}"
def infer_bar_hatch(config: dict) -> str:
    """Return the matplotlib hatch pattern for a scenario's bar.

    Compiled configurations get a "/" hatch so they stand out visually;
    everything else is drawn solid.
    """
    return "/" if config["compile_mode"] is not None else ""
def make_bar_kwargs(
    per_device_data: dict[str, ModelBenchmarkData], key: str
) -> tuple[dict, list, list]:
    """Build the kwargs for ``ax.bar`` plus error values and x-tick pairs.

    Args:
        per_device_data: benchmark data, one entry per device.
        key: which per-scenario metric series to plot (e.g. "ttft" or "itl").

    Returns:
        A 3-tuple of (kwargs dict for ``ax.bar``, per-bar standard deviation
        used for error bars, and ``(x position, device name)`` tick pairs).
        Note: the original annotation claimed only two return values.
    """
    # Prepare accumulators
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
    errors_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            bar_kwargs["x"].append(current_x)
            # Bar height is the median over runs; spread is shown separately
            # as an error bar (std dev).
            bar_kwargs["height"].append(np.median(scenario_data[key]))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            bar_kwargs["hatch"].append(infer_bar_hatch(config))
            errors_bars.append(np.std(scenario_data[key]))
            device_xs.append(current_x)
            current_x += 1
        # One shared tick per device group, centered under its bars.
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5  # visual gap between device groups
    return bar_kwargs, errors_bars, x_ticks
def create_matplotlib_bar_plot() -> "str | None":
    """Create stacked matplotlib bar charts for TTFT and ITL data.

    Loads the benchmark data, checks that all configs agree on batch size,
    sequence length and number of generated tokens, draws one bar chart per
    metric (TTFT on top, ITL below) and returns the figure as an HTML
    snippet with an embedded base64 PNG.

    Returns:
        The HTML string, or None when the configs disagree (the figure then
        only carries an error title).  The original annotation said
        ``-> None``, which contradicted the ``return html`` path.
    """
    # Create figure with dark theme - maximum size for full screen
    plt.style.use("dark_background")
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)  # used to be 30, 16
    fig.patch.set_facecolor("#000000")
    # Load data and ensure coherence
    per_device_data = load_data()
    batch_size, sequence_length, num_tokens_to_generate = None, None, None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if batch_size is None:
            # First device seen defines the reference config.
            batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
        elif (bs, seqlen, n_tok) != (
            batch_size,
            sequence_length,
            num_tokens_to_generate,
        ):
            # Surface the mismatch inside the figure instead of raising.
            fig.suptitle(
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                color="white",
                fontsize=18,
            )
            return None
    # TTFT Plot (top)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
    # ITL Plot (bottom)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)
    # Title and tight layout
    title = "\n".join(
        [
            "Time to first token and inter-token latency (lower is better)",
            f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
        ]
    )
    fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
    plt.tight_layout()
    # Add common legend with full text, de-duplicating repeated labels
    # (both subplots share the same scenarios, so TTFT's labels suffice).
    legend_labels, legend_colors, legend_hatches = [], [], []
    for label, color, hatch in zip(
        ttft_bars["label"], ttft_bars["color"], ttft_bars["hatch"]
    ):
        if label not in legend_labels:
            legend_labels.append(label)
            legend_colors.append(color)
            legend_hatches.append(hatch)
    legend_handles = [
        mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
        for color, hatch, label in zip(legend_colors, legend_hatches, legend_labels)
    ]
    # Put a legend below the plots, centered
    fig.legend(
        handles=legend_handles,
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.515, -0.11),
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        fontsize=14,
    )
    # Save plot to bytes with high DPI for crisp text
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150)
    buffer.seek(0)
    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()
    plt.close(fig)
    # Return HTML with embedded image - full page coverage
    html = f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """
    return html
def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw one metric's bar chart (with error bars) onto *ax*.

    Args:
        ax: target axes, styled in place for the dark theme.
        bar_kwargs: kwargs for ``ax.bar`` as built by ``make_bar_kwargs``.
        errors: one value per bar, drawn as a symmetric error bar.
        ylabel: y-axis label text.
        xticks: ``(x position, device name)`` pairs for the x axis.
        adapt_ylim: when True, shrink the y limits to tightly fit the
            bars (plus a 2% margin around height ± error).
    """
    ax.set_facecolor("#000000")
    ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)
    # Draw bars (result was assigned to an unused `_` before).
    ax.bar(**bar_kwargs, width=1.0, edgecolor="white", linewidth=1, zorder=3)
    # Add error bars on top of the bars (higher zorder).
    ax.errorbar(
        bar_kwargs["x"],
        bar_kwargs["height"],
        yerr=errors,
        fmt="none",
        ecolor="white",
        alpha=0.8,
        elinewidth=1.5,
        capthick=1.5,
        capsize=4,
        zorder=4,
    )
    # Set labels, ticks and grid.  The previous `ax.set_xticks([])` call was
    # dead code: it was immediately overridden by the set_xticks below.
    ax.set_ylabel(ylabel, color="white", fontsize=16)
    ax.tick_params(colors="white", labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
    # Truncate axis to better fit the bars
    if adapt_ylim:
        new_ymin, new_ymax = 1e9, -1e9
        for h, e in zip(bar_kwargs["height"], errors):
            new_ymin = min(new_ymin, 0.98 * (h - e))
            new_ymax = max(new_ymax, 1.02 * (h + e))
        ymin, ymax = ax.get_ylim()
        # Only ever shrink the existing limits, never expand them.
        ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))