# performative_dashboard/bar_plot.py
# Author: ror (HF Staff) -- commit 55c8a69 "Refactor" (5.26 kB)
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
# Color manipulation functions
def hex_to_rgb(hex_color):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts an optional leading ``#`` followed by ``RGB``, ``RGBA``,
    ``RRGGBB`` or ``RRGGBBAA`` hex digits; any alpha channel is ignored.

    Raises:
        ValueError: if the digit count is not 3, 4, 6 or 8, or a digit is not
            hexadecimal (previously a 5- or 7-digit string was silently
            mis-parsed into a wrong color).
    """
    digits = hex_color.lstrip('#')
    if len(digits) in (3, 4):
        # Expand CSS shorthand: "abc" -> "aabbcc" (an alpha digit is dropped below).
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    # Only the first three byte pairs (R, G, B) are read; alpha is ignored.
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def increase_brightness(r, g, b, factor):
    """Lighten an RGB color by blending each channel toward white.

    Every channel moves ``factor`` of the way from its current value to 255,
    so ``factor == 0`` leaves the color unchanged and ``factor == 1`` yields
    white. Results are truncated to ``int`` and returned as a tuple.
    """
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale an RGB color's distance from its luma gray by ``factor``.

    The gray reference is the weighted luma ``0.299*r + 0.587*g + 0.114*b``.
    ``factor > 1`` pushes channels away from gray (more saturated),
    ``factor < 1`` pulls them toward gray (less saturated) and ``factor == 0``
    returns the gray itself. Results are truncated to ``int`` and are NOT
    clamped, so they may fall outside 0-255 when ``factor > 1``.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    scaled = []
    for channel in (r, g, b):
        scaled.append(int(luma + (channel - luma) * factor))
    return tuple(scaled)
def rgb_to_hex(r, g, b):
    """Format RGB channels as a lowercase ``#rrggbb`` string.

    Each channel is rounded to the nearest integer and then clamped to
    [0, 255], so out-of-range or float channel values are accepted.
    (Previously a float channel raised ``ValueError`` from the ``x`` format.)
    """
    red, green, blue = (min(max(int(round(c)), 0), 255) for c in (r, g, b))
    return f"#{red:02x}{green:02x}{blue:02x}"
# Color assignment function
def get_color_for_config(config):
    """Pick a bar color that encodes a benchmark configuration.

    The base hue identifies the attention implementation (and, for SDPA, the
    backend); the compilation and kernelization flags then modify that hue.

    Args:
        config: mapping with keys ``"attn_implementation"``, ``"compilation"``
            and ``"kernelize"``. ``"sdpa_backend"`` is only read when the
            implementation is ``"sdpa"``.

    Returns:
        The color as a ``#rrggbb`` hex string.

    Raises:
        ValueError: for an unknown attention implementation or SDPA backend.
    """
    attn_implementation = config["attn_implementation"]
    if attn_implementation == "eager":
        main_hue = "#FF6B6B"
    elif attn_implementation == "sdpa":
        # Looked up only here so non-SDPA configs need not carry this key.
        sdpa_backend = config["sdpa_backend"]
        # Plain #RRGGBB: only the RGB channels are used for the bar color.
        sdpa_hues = {
            None: "#4A90E2",
            "math": "#408DDB",
            "flash_attention": "#28767E",
            "efficient_attention": "#605895",
            "cudnn_attention": "#774AE2",
        }
        if sdpa_backend not in sdpa_hues:
            raise ValueError(f"Unknown sdpa backend: {sdpa_backend}")
        main_hue = sdpa_hues[sdpa_backend]
    elif attn_implementation == "flash_attention_2":
        main_hue = "#FFD700"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    # Apply color modifications for compilation and kernelization.
    r, g, b = hex_to_rgb(main_hue)
    if config["compilation"]:
        # Compiled variants are drawn as a lighter shade of the same hue.
        r, g, b = increase_brightness(r, g, b, 0.3)
    if config["kernelize"]:
        # NOTE(review): a factor of 0.8 (< 1) pulls channels toward gray, i.e.
        # it *reduces* saturation despite the helper's name -- confirm intended.
        r, g, b = increase_saturation(r, g, b, 0.8)
    return rgb_to_hex(r, g, b)
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Summarize each scenario's samples into ``Axes.bar`` keyword arguments.

    Scenarios are laid out at x = 0, 1, 2, ... in the mapping's iteration
    order. Each bar's height is the median of ``scenario[key]`` and its color
    comes from ``get_color_for_config(scenario["config"])``.

    Returns:
        ``(bar_kwargs, errors)``: ``bar_kwargs`` has the list-valued keys
        ``"x"``, ``"height"``, ``"color"`` and ``"label"`` (the scenario names).
        ``errors`` holds the standard deviation of each scenario's samples in
        the same order.
    """
    xs, heights, colors, labels, errors = [], [], [], [], []
    for position, (scenario_name, scenario) in enumerate(per_scenario_data.items()):
        samples = scenario[key]
        xs.append(position)
        heights.append(np.median(samples))
        colors.append(get_color_for_config(scenario["config"]))
        labels.append(scenario_name)
        errors.append(np.std(samples))
    bar_kwargs = {"x": xs, "height": heights, "color": colors, "label": labels}
    return bar_kwargs, errors
def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    title: str,
    ylabel: str,
    *,
    truncate_ylim: bool = False,
):
    """Draw one dark-themed bar chart with white error bars onto ``ax``.

    Args:
        ax: Axes to draw into; modified in place.
        bar_kwargs: keyword arguments forwarded to ``Axes.bar``. Must contain
            ``"x"`` and ``"height"`` (as built by ``make_bar_kwargs``).
        errors: symmetric error-bar half-lengths, one per bar, aligned with
            ``bar_kwargs["height"]``.
        title: axes title.
        ylabel: y-axis label.
        truncate_ylim: if True, shrink the autoscaled y-limits to the bars'
            error ranges plus a 2% margin, so small differences between tall
            bars are easier to see. The margins assume non-negative values
            (e.g. timings). Defaults to False, which keeps matplotlib's
            autoscaling.
    """
    ax.set_facecolor('#000000')

    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
    )

    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)

    # No x tick labels: bars are identified by color through a legend instead.
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')

    # Guard on emptiness: min()/max() of no bars would raise.
    if truncate_ylim and len(bar_kwargs["height"]) > 0:
        heights = bar_kwargs["height"]
        lows = [h - e for h, e in zip(heights, errors)]
        highs = [h + e for h, e in zip(heights, errors)]
        ymin, ymax = ax.get_ylim()
        # Only ever tighten the autoscaled limits, never widen them.
        ax.set_ylim(max(ymin, 0.98 * min(lows)), min(ymax, 1.02 * max(highs)))
def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Render side-by-side TTFT and ITL bar charts as an embeddable HTML snippet.

    The left panel plots time to first token (``"ttft"``). The right panel
    plots time per output token (``"itl"``). A single shared legend maps bar
    colors to scenario names.

    Args:
        per_scenario_data: mapping of scenario name -> scenario record. Each
            record must provide ``"ttft"`` and ``"itl"`` sample sequences and a
            ``"config"`` mapping (consumed by ``make_bar_kwargs``).

    Returns:
        An HTML ``<div>`` string that embeds the chart as a base64 PNG.
    """
    # Scope the dark theme to this call: plt.style.use() would permanently
    # rewrite the process-wide rcParams for every later figure.
    with plt.style.context('dark_background'):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT plot (left)
            ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
            draw_bar_plot(ax1, ttft_bars, ttft_errors,
                          "Time to first token (lower is better)", "TTFT (seconds)")

            # ITL plot (right)
            itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
            draw_bar_plot(ax2, itl_bars, itl_errors,
                          "Time per output token (lower is better)", "ITL (seconds)")

            # One shared legend with full labels. Both panels iterate the same
            # scenarios in the same order, so the TTFT labels/colors cover both.
            legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color)
                              for color in ttft_bars["color"]]
            fig.legend(legend_handles, ttft_bars["label"], loc='lower center', ncol=1,
                       bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                       labelcolor='white', fontsize=12)

            # Save via the figure object itself rather than pyplot's implicit
            # "current figure"; high DPI keeps the text crisp.
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', facecolor='#000000',
                        bbox_inches='tight', dpi=150)
        finally:
            # Close even when drawing fails: pyplot tracks every open figure
            # until plt.close(), so skipping this leaks memory per call.
            plt.close(fig)

    # Base64-encode the PNG for inline HTML embedding.
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full height
    html = f"""
<div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
</div>
"""
    return html