performative_dashboard / bar_plot.py
ror's picture
ror HF Staff
Better look
165f130
raw
history blame
5.95 kB
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with its entries sorted for plotting.

    Entries are ordered by attention implementation (flash attention 2, then
    SDPA, then eager), then by ``sdpa_backend``, ``kernelize`` and
    ``compilation``, all read from each entry's ``"config"`` dict.

    Args:
        per_scenario_data: Mapping of scenario key to a dict holding at least a
            ``"config"`` dict with the four keys listed above.

    Returns:
        A new dict with the same keys and values, inserted in sorted order.

    Raises:
        KeyError: If a config lacks one of the keys, or names an attention
            implementation other than the three known ones.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        backend = cfg["sdpa_backend"]
        # The `is not None` flag is compared before the backend itself, so an
        # unset backend (None) sorts first and is never ordered against a
        # string, which would raise TypeError.
        return (
            attn_priority[cfg["attn_implementation"]],
            backend is not None,
            backend,
            cfg["kernelize"],
            cfg["compilation"],
        )

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sorting_fn)}
def infer_bar_label(config: dict) -> str:
    """Build the human-readable legend entry for one benchmark configuration.

    The label joins the attention implementation's display name with the
    compilation and kernelization status read from ``config``.
    """
    display_names = {
        "flash_attention_2": "Flash attention",
        "sdpa": "SDPA",
        "eager": "Eager",
    }
    parts = [
        display_names[config["attn_implementation"]],
        "compiled" if config["compilation"] else "no compile",
        "kernelized" if config["kernelize"] else "no kernels",
    ]
    return ", ".join(parts)
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Collect bar positions, heights, colors and labels for one metric.

    Bars of one device sit next to each other (one x unit apart); consecutive
    devices are separated by an extra gap of 1.5 x units.

    Args:
        per_device_data: Mapping of device name to an object exposing
            ``get_bar_plot_data()``, which returns a mapping of scenario name to
            a dict holding a ``"config"`` dict and the samples stored under ``key``.
        key: Name of the metric to plot, e.g. ``"ttft"`` or ``"itl"``.

    Returns:
        A ``(bar_kwargs, errors_bars, x_ticks)`` tuple: ``bar_kwargs`` holds the
        parallel ``x``/``height``/``color``/``label`` lists (height is the median
        of the samples), ``errors_bars`` the standard deviation of each bar's
        samples, and ``x_ticks`` one ``(position, device_name)`` pair per device,
        positioned at the center of that device's bars.
    """
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    errors_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # A device without scenarios has no bars: the mean of its (empty)
            # positions would be NaN, so it gets neither a tick nor a gap.
            continue
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            samples = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            errors_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1
        # Center the device name under its group of bars.
        x_ticks.append((np.mean(device_xs), device_name))
        # Leave a visual gap before the next device's bars.
        current_x += 1.5
    return bar_kwargs, errors_bars, x_ticks
def _draw_latency_figure(fig, axs) -> None:
    """Fill ``fig`` with the TTFT (top) and ITL (bottom) bar plots, or an error title."""
    fig.patch.set_facecolor('#000000')
    # Load and sanitize data
    per_device_data = load_data()
    batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
    if len(set(batch_sizes.values())) > 1:
        # Devices benchmarked with different batch sizes are not comparable:
        # show only the message. Note that suptitle() forwards extra keyword
        # arguments to Text, which has no `pad` property, so none is passed.
        for ax in axs:
            ax.set_axis_off()
        fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18)
        return
    # TTFT plot (top)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(
        axs[0], ttft_bars, ttft_errors,
        "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks,
    )
    # ITL plot (bottom); it shares the title drawn above the top plot
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
    fig.tight_layout()
    # Common legend built from the first half of the bars.
    # NOTE(review): this assumes exactly two devices sharing the same scenarios - confirm.
    unique_bars = len(ttft_bars["label"]) // 2
    legend_labels = ttft_bars["label"][:unique_bars]
    legend_colors = ttft_bars["color"][:unique_bars]
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_colors]
    # Anchor the legend below the axes; bbox_inches='tight' keeps it in the saved image
    fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
               bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
               labelcolor='white', fontsize=14)


def _figure_to_html(fig) -> str:
    """Render ``fig`` as a PNG and embed it, base64-encoded, in a full-page HTML snippet."""
    buffer = io.BytesIO()
    # High DPI for crisp text
    fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    return f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
    <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """


def create_matplotlib_bar_plot() -> str:
    """Create stacked matplotlib bar charts for TTFT and ITL and return them as HTML.

    The data comes from ``load_data()``. When the devices were not benchmarked
    with the same main batch size, the returned image only shows a message
    listing the mismatched batch sizes.

    Returns:
        An HTML snippet embedding the rendered figure as a base64 PNG.
    """
    # Dark theme, large canvas so the plot stays readable full screen
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
    try:
        _draw_latency_figure(fig, axs)
        return _figure_to_html(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it on every path, including errors.
        plt.close(fig)
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str, xticks: list[tuple[float, str]]):
    """Draw one bar chart with error bars on ``ax``, using the dark dashboard style.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``; its ``"x"`` and
            ``"height"`` entries are also used to place the error bars.
        errors: Error-bar half-length of each bar, parallel to ``bar_kwargs["height"]``.
        title: Title of the plot; callers in this file also pass ``None`` for no title.
        ylabel: Label of the y axis.
        xticks: ``(position, label)`` pairs for the x axis ticks.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.2, color='white', zorder=0)
    # Bars are drawn above the grid (zorder), error bars above the bars
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )
    # Labels and title
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.set_title(title, color='white', fontsize=18, pad=20)
    # Ticks
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([position for position, _ in xticks], [label for _, label in xticks], fontsize=16)
    # Truncate the y axis to the span of the bars and their error bars (with a
    # 2% margin) so that differences between bars stay visible.
    spans = [(h - e, h + e) for h, e in zip(bar_kwargs["height"], errors)]
    if not spans:
        # Nothing was drawn: keep the default limits instead of deriving
        # inverted ones from an empty span.
        return
    new_ymin = min(0.98 * low for low, _ in spans)
    new_ymax = max(1.02 * high for _, high in spans)
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))