# performative_dashboard / bar_plot.py
# ror's picture
# ror HF Staff
# Better plot
# e1f4b73
# raw
# history blame
# 6.95 kB
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
# Color manipulation functions
def hex_to_rgb(hex_color):
    """Convert a hex color string into an ``(r, g, b)`` tuple of ints.

    Leading ``#`` characters are stripped and only the first six hex digits
    are read, so an ``#RRGGBBAA`` value is accepted and its trailing alpha
    pair is ignored.
    """
    digits = hex_color.lstrip('#')
    # Each channel is one two-character slice: offsets 0, 2 and 4.
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return red, green, blue
def blend_colors(rgb, hex_color, blend_strength):
    """Mix an RGB triple with the color described by ``hex_color``.

    Each output channel is
    ``rgb[i] * blend_strength + other[i] * (1 - blend_strength)`` truncated
    with ``int``, so a ``blend_strength`` of 1 keeps ``rgb`` and 0 yields the
    color parsed from ``hex_color``. Returns a 3-tuple of ints.
    """
    other_rgb = hex_to_rgb(hex_color)
    other_weight = 1 - blend_strength
    mixed = []
    for channel in range(3):
        value = rgb[channel] * blend_strength + other_rgb[channel] * other_weight
        mixed.append(int(value))
    return tuple(mixed)
def increase_brightness(r, g, b, factor):
    """Move each channel toward 255 by ``factor`` of its remaining headroom.

    A ``factor`` of 0 leaves the color unchanged and 1 gives white; results
    are truncated with ``int`` and returned as a 3-tuple.
    """
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))
def decrease_brightness(r, g, b, factor):
    """Scale each channel by ``factor`` and truncate to ``int``.

    A ``factor`` of 1 leaves the color unchanged and 0 gives black; the
    result is returned as a 3-tuple.
    """
    return tuple(int(channel * factor) for channel in (r, g, b))
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the color's gray level by ``factor``.

    The gray level is the weighted sum ``0.299*r + 0.587*g + 0.114*b``. A
    ``factor`` of 0 collapses every channel onto that gray level. Results are
    truncated with ``int`` and are not clamped to the 0-255 range.
    """
    gray_level = 0.299 * r + 0.587 * g + 0.114 * b
    return tuple(int(gray_level + (channel - gray_level) * factor) for channel in (r, g, b))
def rgb_to_hex(r, g, b):
    """Format integer RGB channels as a lowercase ``#rrggbb`` string.

    Every channel is clamped into the 0-255 range before formatting.
    """
    clamped = [min(max(channel, 0), 255) for channel in (r, g, b)]
    return "#" + "".join(f"{channel:02x}" for channel in clamped)
# Color assignment function
def get_color_for_config(config, filtered_on_compile_mode: bool = False):
    """Pick the bar color for a benchmark configuration.

    The base hue depends on ``config["attn_implementation"]`` (and, for
    "sdpa", on ``config["sdpa_backend"]``), with a separate shade for
    compiled and non-compiled runs. Compiled configurations are then
    brightened and kernelized ones darkened.

    Args:
        config: mapping read for the keys "attn_implementation",
            "sdpa_backend", "compilation" and "kernelize"; "compile_mode" is
            read only when ``filtered_on_compile_mode`` is true and the
            configuration is compiled.
        filtered_on_compile_mode: when true, the brightening applied to
            compiled configurations also depends on the length of
            ``config["compile_mode"]``.

    Returns:
        The color as a ``#rrggbb`` hex string.

    Raises:
        ValueError: if the attention implementation is not one of "eager",
            "sdpa" or "flash_attention_2".
        KeyError: if the attention implementation is "sdpa" and the backend
            has no entry in the hue table.
    """
    attn_implementation = config["attn_implementation"]
    sdpa_backend = config["sdpa_backend"]
    compiled = config["compilation"]

    # Every entry is a (compiled, not compiled) pair of hues.
    sdpa_hues = {
        None: ("#4A90E2", "#2E82E1FF"),
        "math": ("#408DDB", "#227BD3FF"),
        "flash_attention": ("#35A34D", "#219F3CFF"),
        "efficient_attention": ("#605895", "#423691FF"),
        "cudnn_attention": ("#774AE2", "#5D27DCFF"),
    }

    if attn_implementation == "eager":
        hue_pair = ("#FF4B4BFF", "#FF4141FF")
    elif attn_implementation == "sdpa":
        hue_pair = sdpa_hues[sdpa_backend]
    elif attn_implementation == "flash_attention_2":
        hue_pair = ("#FFD700", "#FFBF00FF")
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")
    main_hue = hue_pair[0] if compiled else hue_pair[1]

    r, g, b = hex_to_rgb(main_hue)

    # Compiled runs are brightened; the amount grows with the length of the
    # compile mode name when the view is filtered on compile mode.
    if compiled:
        delta = 0.2
        if filtered_on_compile_mode:
            delta += 0.2 * (len(config["compile_mode"]) - 7) / 8
        r, g, b = increase_brightness(r, g, b, delta)

    # Kernelized runs are darkened.
    # NOTE(review): the source formatting does not make the extent of this
    # branch explicit; the darkening is treated as kernelize-only - confirm.
    if config["kernelize"]:
        r, g, b = decrease_brightness(r, g, b, 0.8)

    return rgb_to_hex(r, g, b)
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with its entries in display order.

    Entries are ordered by attention implementation ("flash_attention_2",
    then "sdpa", then "eager"), then by SDPA backend, then by the "kernelize"
    and "compilation" flags. Entries that compare equal keep their original
    relative order.

    Args:
        per_scenario_data: mapping of scenario name to a dict whose "config"
            entry provides "attn_implementation", "sdpa_backend", "kernelize"
            and "compilation".

    Returns:
        A new dict holding the same values, inserted in sorted order.

    Raises:
        KeyError: if a config is missing one of those keys or uses an
            attention implementation without a known priority.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        priority = attn_priority[cfg["attn_implementation"]]
        backend = cfg["sdpa_backend"]
        # A backend of None cannot be ordered against a backend name, so the
        # "is set" flag is compared first (None sorts ahead of any name) and
        # None is replaced by an empty string for the name comparison.
        return (
            priority,
            backend is not None,
            "" if backend is None else backend,
            cfg["kernelize"],
            cfg["compilation"],
        )

    ordered_names = sorted(per_scenario_data, key=sorting_fn)
    return {name: per_scenario_data[name] for name in ordered_names}
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Build the bar keyword arguments and error values for one metric.

    Args:
        per_scenario_data: mapping of scenario name to a dict holding the
            samples under ``key`` and the scenario's "config".
        key: name of the metric whose samples are summarized.

    Returns:
        A ``(bar_kwargs, errors)`` pair. ``bar_kwargs`` has the keys "x"
        (the entry's position), "height" (median of the samples), "color"
        (from ``get_color_for_config``) and "label" (the scenario name);
        ``errors`` holds the standard deviation of each entry's samples.
    """
    positions, heights, colors, labels, errors = [], [], [], [], []
    for position, (name, data) in enumerate(per_scenario_data.items()):
        samples = data[key]
        positions.append(position)
        heights.append(np.median(samples))
        colors.append(get_color_for_config(data["config"]))
        labels.append(name)
        errors.append(np.std(samples))
    bar_kwargs = {"x": positions, "height": heights, "color": colors, "label": labels}
    return bar_kwargs, errors
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
    """Draw one white-on-black bar chart with error bars on ``ax``.

    Args:
        ax: axes to draw on.
        bar_kwargs: keyword arguments forwarded to ``ax.bar``; its "x" and
            "height" entries are also used for the error bars and the y-axis
            truncation.
        errors: one error value per bar, drawn as a symmetric error bar.
        title: axes title.
        ylabel: y-axis label.
    """
    ax.set_facecolor('#000000')

    # Bars are one unit wide, so adjacent bars touch.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
    )

    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.set_title(title, color='white', fontsize=18, pad=20)

    # Bars are identified by the figure legend, so no x ticks are drawn.
    ax.set_xticks([])
    ax.tick_params(colors='white', labelsize=13)

    # Truncate the y-axis to the span covered by the bars and their error
    # bars, with a 2% margin on each side.
    new_ymin, new_ymax = float("inf"), float("-inf")
    for h, e in zip(bar_kwargs["height"], errors):
        new_ymin = min(new_ymin, 0.98 * (h - e))
        new_ymax = max(new_ymax, 1.02 * (h + e))
    ymin, ymax = ax.get_ylim()
    lower, upper = max(ymin, new_ymin), min(ymax, new_ymax)
    # With no bars (or only NaN values) the sentinels survive and the window
    # is empty; keep the automatic limits rather than inverting the axis.
    if lower < upper:
        ax.set_ylim(lower, upper)
def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Render TTFT, ITL and E2E bar charts side by side as an HTML snippet.

    The scenarios are reordered with ``reorder_data``, one chart is drawn per
    metric ("ttft", "itl", "e2e") and a shared legend lists every scenario.
    The figure is saved as a PNG and embedded as a base64 data URI.

    Args:
        per_scenario_data: mapping of scenario name to a dict holding the
            "ttft", "itl" and "e2e" samples and the scenario's "config".

    Returns:
        An HTML string: a ``<div>`` wrapping an ``<img>`` whose source is the
        base64-encoded PNG.
    """
    # NOTE: style.use changes matplotlib's global style, not just this figure.
    plt.style.use('dark_background')
    fig, axs = plt.subplots(1, 3, figsize=(30, 16))
    try:
        fig.patch.set_facecolor('#000000')

        per_scenario_data = reorder_data(per_scenario_data)

        # TTFT plot (left)
        ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")
        # ITL plot (middle)
        itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, "Inter token latency (lower is better)", "ITL (seconds)")
        # E2E plot (right)
        e2e_bars, e2e_errors = make_bar_kwargs(per_scenario_data, "e2e")
        draw_bar_plot(axs[2], e2e_bars, e2e_errors, "End-to-end latency (lower is better)", "E2E (seconds)")

        # One legend for the whole figure; labels and colors are the same in
        # every panel, so the first panel's are used.
        legend_labels = ttft_bars["label"]
        legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in ttft_bars["color"]]
        fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
                   bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=14)

        # Save this figure explicitly (not whichever figure pyplot considers
        # current) at a high DPI for crisp text.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    html = f"""
<div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
</div>
"""
    return html