Spaces:
Running
Running
File size: 5,262 Bytes
55c8a69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
# Color manipulation functions
def hex_to_rgb(hex_color):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in [0, 255].

    A leading ``#`` is optional. Accepted forms are ``RGB``, ``RGBA``,
    ``RRGGBB`` and ``RRGGBBAA``; an alpha channel, if present, is ignored.

    Raises:
        ValueError: if the string has an unsupported length or contains
            non-hexadecimal characters.
    """
    digits = hex_color.lstrip('#')
    if len(digits) in (3, 4):
        # CSS-style shorthand: every digit is doubled, e.g. "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    # Only the first three byte pairs (RGB) are read; a trailing alpha pair is dropped.
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def increase_brightness(r, g, b, factor):
    """Blend each channel toward white (255) by ``factor``.

    A ``factor`` of 0 leaves the color unchanged and 1 yields pure white.
    Results are truncated with ``int`` (not rounded) and are not clamped here.
    """
    return tuple(int(channel + factor * (255 - channel)) for channel in (r, g, b))
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from its gray level by ``factor``.

    The gray level uses the BT.601 luma weights (0.299, 0.587, 0.114).
    ``factor`` > 1 pushes channels away from gray (more saturated), while
    ``factor`` < 1 pulls them toward gray (less saturated). Results are
    truncated with ``int`` and may fall outside [0, 255]; no clamping is done.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    scaled = []
    for channel in (r, g, b):
        scaled.append(int(luma + (channel - luma) * factor))
    return tuple(scaled)
def rgb_to_hex(r, g, b):
    """Format three integer channels as a lowercase ``#rrggbb`` string.

    Each channel is clamped into [0, 255] first, so out-of-range values
    produced by the color helpers are tolerated.
    """
    clamped = [255 if value > 255 else 0 if value < 0 else value for value in (r, g, b)]
    return "#" + "".join(f"{value:02x}" for value in clamped)
# Color assignment function
def get_color_for_config(config):
    """Return the bar color for a benchmark configuration.

    The base color is chosen from ``config["attn_implementation"]``. For
    ``"sdpa"``, it also depends on ``config.get("sdpa_backend")``. The color is
    then lightened when ``config["compilation"]`` is truthy and blended toward
    gray when ``config["kernelize"]`` is truthy.

    Args:
        config: mapping with keys ``"attn_implementation"``, ``"compilation"``
            and ``"kernelize"``. ``"sdpa_backend"`` is optional and treated as
            ``None`` when absent; it is only consulted for ``"sdpa"``.

    Returns:
        A hex color string as produced by ``rgb_to_hex``.

    Raises:
        ValueError: if the attention implementation or SDPA backend is unknown.
        KeyError: if ``attn_implementation``, ``compilation`` or ``kernelize``
            is missing.
    """
    attn_implementation = config["attn_implementation"]
    # Only SDPA distinguishes backends, so other configs may omit this key.
    sdpa_backend = config.get("sdpa_backend")

    # Determine the main hue for the attention implementation.
    if attn_implementation == "eager":
        main_hue = "#FF6B6B"
    elif attn_implementation == "sdpa":
        # The trailing "FF" on some entries is an alpha byte; hex_to_rgb only
        # reads the leading RGB pairs.
        sdpa_hues = {
            None: "#4A90E2",
            "math": "#408DDBFF",
            "flash_attention": "#28767EFF",
            "efficient_attention": "#605895FF",
            "cudnn_attention": "#774AE2FF",
        }
        try:
            main_hue = sdpa_hues[sdpa_backend]
        except KeyError:
            raise ValueError(f"Unknown SDPA backend: {sdpa_backend}") from None
    elif attn_implementation == "flash_attention_2":
        main_hue = "#FFD700"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    # Apply color modifications for compilation and kernelization.
    r, g, b = hex_to_rgb(main_hue)
    if config["compilation"]:
        r, g, b = increase_brightness(r, g, b, 0.3)
    if config["kernelize"]:
        # NOTE(review): a factor of 0.8 (< 1) moves the color toward gray, i.e. it
        # *reduces* saturation despite the helper's name. Confirm this is intended.
        r, g, b = increase_saturation(r, g, b, 0.8)
    return rgb_to_hex(r, g, b)
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Build ``Axes.bar`` keyword arguments and error sizes for one metric.

    Each scenario, in dict iteration order, contributes one bar:

    * position: its index;
    * height: the median of ``data[key]``;
    * color: ``get_color_for_config(data["config"])``;
    * label: the scenario name.

    Returns:
        ``(bar_kwargs, errors)``. ``bar_kwargs`` maps ``"x"``, ``"height"``,
        ``"color"`` and ``"label"`` to lists. ``errors`` holds ``np.std`` of
        ``data[key]`` per scenario; note the bars show the median while the
        error bars are the standard deviation.
    """
    positions, heights, colors, labels = [], [], [], []
    errors = []
    for position, (scenario_name, scenario) in enumerate(per_scenario_data.items()):
        samples = scenario[key]
        positions.append(position)
        heights.append(np.median(samples))
        colors.append(get_color_for_config(scenario["config"]))
        labels.append(scenario_name)
        errors.append(np.std(samples))
    bar_kwargs = {"x": positions, "height": heights, "color": colors, "label": labels}
    return bar_kwargs, errors
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
    """Draw one metric's bar chart, with error bars, on a black background.

    Args:
        ax: Axes to draw into; modified in place.
        bar_kwargs: keyword arguments forwarded to ``ax.bar``. Its ``"x"`` and
            ``"height"`` entries also position the error bars.
        errors: symmetric y-error for each bar, aligned with ``bar_kwargs["x"]``.
        title: axes title.
        ylabel: y-axis label.
    """
    ax.set_facecolor('#000000')

    # Bars first, then the error bars overlaid at the same x/height points.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    error_style = dict(fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4)
    ax.errorbar(bar_kwargs["x"], bar_kwargs["height"], yerr=errors, **error_style)

    # Text and axis decorations, all white for the dark theme.
    ax.set_title(title, color='white', fontsize=16, pad=20)
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_xticks([])  # no x tick labels; bars are distinguished by color
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')
    # NOTE: truncating the y-axis to about +/-2% around the error-bar extents was
    # previously sketched here (commented out) and is currently disabled.
def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Create side-by-side matplotlib bar charts for TTFT and ITL data.

    The left panel shows time to first token (the ``"ttft"`` samples). The right
    panel shows time per output token (the ``"itl"`` samples). Both panels share
    one figure-level legend built from the scenario names and bar colors.

    The figure is rendered to a PNG at 150 dpi and returned inline as a base64
    data URI.

    Args:
        per_scenario_data: mapping of scenario name -> scenario data, in the form
            consumed by ``make_bar_kwargs`` (``"ttft"`` / ``"itl"`` samples plus
            a ``"config"`` entry).

    Returns:
        An HTML string containing the embedded plot image.
    """
    # Apply the dark theme only while building this figure; plt.style.use would
    # permanently change the global rcParams for every later plot in the process.
    with plt.style.context('dark_background'):
        # Large figure for more screen space.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT plot (left)
            ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
            draw_bar_plot(ax1, ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")

            # Time-per-output-token plot (right)
            itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
            draw_bar_plot(ax2, itl_bars, itl_errors, "Time per output token (lower is better)", "ITL (seconds)")

            # Common legend with full (untruncated) scenario labels.
            legend_labels = ttft_bars["label"]
            legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in ttft_bars["color"]]
            fig.legend(legend_handles, legend_labels, loc='lower center', ncol=1,
                       bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                       labelcolor='white', fontsize=12)

            # Save this specific figure (not pyplot's "current" one) at high DPI
            # for crisp text; bbox_inches='tight' keeps the legend below the axes.
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', facecolor='#000000',
                        bbox_inches='tight', dpi=150)
        finally:
            # pyplot keeps figures registered until closed; close even on error
            # so repeated failures do not accumulate open figures.
            plt.close(fig)

    # Convert to base64 for HTML embedding.
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full height
    html = f"""
    <div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
    <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
    </div>
    """
    return html
|