import base64
import io

import matplotlib.pyplot as plt
import numpy as np


# Color manipulation functions


def hex_to_rgb(hex_color):
    """Parse a ``#RRGGBB`` (or ``#RRGGBBAA``) string into an ``(r, g, b)`` int tuple.

    Only the first six hex digits are read, so a trailing alpha channel is ignored.
    """
    hex_color = hex_color.lstrip('#')
    r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
    return r, g, b


def increase_brightness(r, g, b, factor) -> tuple[int, int, int]:
    """Move each channel ``factor`` of the way towards white (255), truncating to int."""
    return tuple(int(c + (255 - c) * factor) for c in (r, g, b))


def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the luma-weighted gray by ``factor``.

    ``factor > 1`` pushes channels away from gray (more saturated); ``factor < 1``
    pulls them towards gray (less saturated). Results are truncated to int and may
    fall outside 0..255; ``rgb_to_hex`` clamps them.
    """
    gray = 0.299 * r + 0.587 * g + 0.114 * b
    return tuple(int(gray + (c - gray) * factor) for c in (r, g, b))


def rgb_to_hex(r, g, b):
    """Clamp channels to 0..255 and format them as a lowercase ``#rrggbb`` string."""
    # int() lets float channels through; the `:02x` format spec rejects floats.
    r, g, b = (int(min(max(c, 0), 255)) for c in (r, g, b))
    return f"#{r:02x}{g:02x}{b:02x}"


# Color assignment function

# Base hue for each SDPA backend. The 8-digit entries carry an alpha byte that
# hex_to_rgb ignores.
_SDPA_BACKEND_HUES = {
    None: "#4A90E2",
    "math": "#408DDBFF",
    "flash_attention": "#28767EFF",
    "efficient_attention": "#605895FF",
    "cudnn_attention": "#774AE2FF",
}


def get_color_for_config(config):
    """Return the bar color (``#rrggbb``) for a benchmark configuration.

    The base hue comes from ``config["attn_implementation"]`` and, for SDPA,
    ``config["sdpa_backend"]``. It is then brightened when ``config["compilation"]``
    is truthy and re-saturated when ``config["kernelize"]`` is truthy.

    Raises:
        ValueError: unknown attention implementation.
        KeyError: unknown SDPA backend, or a missing config key.
    """
    # Determine the main hue for the attention implementation
    attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
    if attn_implementation == "eager":
        main_hue = "#FF6B6B"
    elif attn_implementation == "sdpa":
        main_hue = _SDPA_BACKEND_HUES[sdpa_backend]
    elif attn_implementation == "flash_attention_2":
        main_hue = "#FFD700"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    # Apply color modifications for compilation and kernelization
    r, g, b = hex_to_rgb(main_hue)
    if config["compilation"]:
        r, g, b = increase_brightness(r, g, b, 0.3)
    if config["kernelize"]:
        # NOTE(review): a factor of 0.8 moves the color towards gray, i.e. it
        # *reduces* saturation despite the helper's name. Confirm this is intended.
        r, g, b = increase_saturation(r, g, b, 0.8)

    # Return the color as a hex string
    return rgb_to_hex(r, g, b)


def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Build ``Axes.bar`` keyword arguments and error values for one metric.

    Args:
        per_scenario_data: mapping of scenario name -> data dict. Each data dict
            holds a sequence of samples under ``key`` and a ``"config"`` dict
            understood by ``get_color_for_config``.
        key: which metric to plot (e.g. ``"ttft"`` or ``"itl"``).

    Returns:
        ``(bar_kwargs, errors)``. Bar heights are the sample medians; errors are
        the sample standard deviations, in scenario order.
    """
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    errors = []
    for i, (name, data) in enumerate(per_scenario_data.items()):
        bar_kwargs["x"].append(i)
        bar_kwargs["height"].append(np.median(data[key]))
        bar_kwargs["color"].append(get_color_for_config(data["config"]))
        bar_kwargs["label"].append(name)
        errors.append(np.std(data[key]))
    return bar_kwargs, errors


def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
    """Draw a dark-themed bar chart with white error bars onto ``ax``."""
    ax.set_facecolor('#000000')
    # Draw bars
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    # Add error bars
    ax.errorbar(
        bar_kwargs["x"],
        bar_kwargs["height"],
        yerr=errors,
        fmt='none',
        ecolor='white',
        alpha=0.8,
        elinewidth=1.5,
        capthick=1.5,
        capsize=4,
    )
    # Set labels and title
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    # Set ticks and grid (bars are identified through the figure legend, not x ticks)
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')
    # Truncate axis to better fit the bars (currently disabled)
    # new_ymin, new_ymax = 1e9, -1e9
    # for h, e in zip(bar_kwargs["height"], errors):
    #     new_ymin = min(new_ymin, 0.98 * (h - e))
    #     new_ymax = max(new_ymax, 1.02 * (h + e))
    # ymin, ymax = ax.get_ylim()
    # ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))


def create_matplotlib_bar_plot(per_scenario_data: dict, dpi: int = 150):
    """Create side-by-side matplotlib bar charts for TTFT and ITL data.

    Args:
        per_scenario_data: see ``make_bar_kwargs``. Each scenario needs ``"ttft"``,
            ``"itl"`` and ``"config"`` entries.
        dpi: resolution of the rendered PNG.

    Returns:
        An HTML snippet that embeds the rendered PNG as a base64 ``data:`` URI.
    """
    # Scope the dark style to this figure instead of mutating global rcParams.
    with plt.style.context('dark_background'):
        # Large figure for more screen space
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT Plot (left)
            ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
            draw_bar_plot(ax1, ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")

            # ITL / TPOT Plot (right)
            itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
            draw_bar_plot(ax2, itl_bars, itl_errors, "Time per output token (lower is better)", "ITL (seconds)")

            # Common legend with full (untruncated) scenario names; both plots share
            # the same scenario order and colors, so the TTFT entries serve for both.
            legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in ttft_bars["color"]]
            fig.legend(
                legend_handles,
                ttft_bars["label"],
                loc='lower center',
                ncol=1,
                bbox_to_anchor=(0.5, -0.05),
                facecolor='black',
                edgecolor='white',
                labelcolor='white',
                fontsize=12,
            )
            # plt.subplots_adjust(wspace=0.3, bottom=0.075)

            # Save the plot to bytes; bbox_inches='tight' keeps the legend below the axes.
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=dpi)
        finally:
            # Always release the figure, even if plotting raised.
            plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with the embedded image, scaled to the container's full height
    return (
        '<div style="width: 100%; height: 100%;">'
        f'<img src="data:image/png;base64,{img_data}" '
        'alt="TTFT and ITL bar charts" '
        'style="width: 100%; height: 100%; object-fit: contain;">'
        '</div>'
    )