Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import io | |
| import numpy as np | |
| import base64 | |
| # Color manipulation functions | |
def hex_to_rgb(hex_color):
    """Parse a hex colour string into an ``(r, g, b)`` tuple of ints.

    All leading ``#`` characters are stripped, then only the first six hex
    digits are read, so an ``#RRGGBBAA`` string yields just its RGB part.
    """
    digits = hex_color.lstrip('#')
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))
def increase_brightness(r, g, b, factor):
    """Blend each channel toward white (255) by ``factor``.

    ``factor=0`` leaves the colour unchanged and ``factor=1`` yields white.
    Each resulting channel is truncated to an int with ``int()``.
    """
    def _lighten(channel):
        return int(channel + (255 - channel) * factor)

    return (_lighten(r), _lighten(g), _lighten(b))
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the colour's grey level by ``factor``.

    The grey level is the weighted sum ``0.299*r + 0.587*g + 0.114*b``.
    ``factor > 1`` pushes channels away from grey (more saturated), while
    ``0 <= factor < 1`` pulls them toward it (less saturated). Results are
    truncated with ``int()`` and are not clamped here.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    return tuple(int(luma + (channel - luma) * factor) for channel in (r, g, b))
def rgb_to_hex(r, g, b):
    """Format an RGB triple as a lowercase ``#rrggbb`` hex string.

    Each channel is clamped to ``[0, 255]`` and then rounded to the nearest
    integer. Float channels (e.g. from colour arithmetic) are therefore
    accepted as well as ints; previously the ``x`` format spec rejected them
    with ``ValueError``. Integer inputs produce exactly the same output as
    before.
    """
    def _to_byte(value):
        # Clamp first so the rounded result always fits in two hex digits.
        return int(round(min(max(value, 0), 255)))

    r, g, b = (_to_byte(channel) for channel in (r, g, b))
    return f"#{r:02x}{g:02x}{b:02x}"
| # Color assignment function | |
def get_color_for_config(config):
    """Return the ``#rrggbb`` bar colour for one benchmark configuration.

    The base hue comes from ``config["attn_implementation"]``; for ``"sdpa"``
    it additionally depends on ``config["sdpa_backend"]``. When
    ``config["compilation"]`` is truthy the colour is lightened 30% toward
    white. When ``config["kernelize"]`` is truthy it is then placed 80% of
    the way from its grey level back to the colour.

    NOTE(review): a factor of 0.8 (< 1) passed to ``increase_saturation``
    actually *reduces* saturation. Confirm that this is the intended visual
    cue.

    Raises:
        KeyError: if ``config`` lacks a required key, or the SDPA backend is
            not one of the known backends.
        ValueError: if ``attn_implementation`` is not recognised.
    """
    attn_implementation = config["attn_implementation"]
    sdpa_backend = config["sdpa_backend"]

    if attn_implementation == "eager":
        base_hex = "#FF6B6B"
    elif attn_implementation == "sdpa":
        # Any trailing "FF" alpha byte is ignored: only six hex digits are parsed.
        sdpa_hues = {
            None: "#4A90E2",
            "math": "#408DDBFF",
            "flash_attention": "#28767EFF",
            "efficient_attention": "#605895FF",
            "cudnn_attention": "#774AE2FF",
        }
        base_hex = sdpa_hues[sdpa_backend]
    elif attn_implementation == "flash_attention_2":
        base_hex = "#FFD700"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    # Layer the optional modifiers on top of the base hue, brightness first.
    rgb = hex_to_rgb(base_hex)
    if config["compilation"]:
        rgb = increase_brightness(*rgb, 0.3)
    if config["kernelize"]:
        rgb = increase_saturation(*rgb, 0.8)
    return rgb_to_hex(*rgb)
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Collect ``ax.bar`` keyword arguments and error values for one metric.

    Bars are placed at x = 0, 1, 2, ... in the iteration order of
    ``per_scenario_data``.

    Args:
        per_scenario_data: Mapping of scenario name to scenario data. Each
            value is indexed with ``key`` (the metric samples) and
            ``"config"`` (passed to ``get_color_for_config``).
        key: Name of the metric to plot, e.g. ``"ttft"``.

    Returns:
        A tuple ``(bar_kwargs, errors)``:

        - ``bar_kwargs`` has list values under ``"x"``, ``"height"`` (median
          of the samples), ``"color"`` and ``"label"`` (scenario name).
        - ``errors`` holds ``np.std`` of each scenario's samples.

    NOTE(review): bar height is the median while the error is the standard
    deviation (spread around the mean). Confirm this mix is intended.
    """
    positions, heights, colors, labels, spreads = [], [], [], [], []
    for position, (scenario_name, scenario) in enumerate(per_scenario_data.items()):
        samples = scenario[key]
        positions.append(position)
        heights.append(np.median(samples))
        colors.append(get_color_for_config(scenario["config"]))
        labels.append(scenario_name)
        spreads.append(np.std(samples))
    bar_kwargs = {"x": positions, "height": heights, "color": colors, "label": labels}
    return bar_kwargs, spreads
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str,
                  *, truncate_ylim: bool = False):
    """Draw one bar series with white error bars on ``ax`` using a black theme.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``. Must contain
            ``"x"`` and ``"height"`` (e.g. as produced by ``make_bar_kwargs``).
        errors: Symmetric error-bar half-lengths, one per bar.
        title: Axes title.
        ylabel: Y-axis label.
        truncate_ylim: If True, narrow the y-axis to roughly 2% beyond the
            lowest ``height - error`` and highest ``height + error``. The
            default limits are never widened. Off by default, which keeps the
            original behaviour.
    """
    ax.set_facecolor('#000000')

    # Bars, then error bars overlaid at the same x positions and heights.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
    )

    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)

    # No x tick labels: individual bars are distinguished by colour only.
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')

    if truncate_ylim:
        lows = [0.98 * (h - e) for h, e in zip(bar_kwargs["height"], errors)]
        highs = [1.02 * (h + e) for h, e in zip(bar_kwargs["height"], errors)]
        # Skip when there are no bars; min()/max() of an empty list would raise.
        if lows:
            ymin, ymax = ax.get_ylim()
            ax.set_ylim(max(ymin, min(lows)), min(ymax, max(highs)))
def _figure_to_base64_png(fig) -> str:
    """Render ``fig`` as a 150-DPI PNG on a black background and return it base64-encoded."""
    buffer = io.BytesIO()
    # Save this specific figure, not pyplot's global "current" figure.
    fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150)
    return base64.b64encode(buffer.getvalue()).decode()


def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Create side-by-side matplotlib bar charts for TTFT and ITL data.

    Args:
        per_scenario_data: Mapping of scenario name to scenario data. Each
            value is read via ``make_bar_kwargs`` with the keys ``"ttft"``,
            ``"itl"`` and ``"config"``.

    Returns:
        An HTML string: a full-height ``<div>`` containing the figure as an
        embedded base64 PNG ``<img>``.
    """
    # Apply the dark theme only while building this figure, instead of
    # permanently changing the global matplotlib style for the process.
    with plt.style.context('dark_background'):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT plot (left)
            ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
            draw_bar_plot(ax1, ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")

            # Inter-token latency / time-per-output-token plot (right)
            itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
            draw_bar_plot(ax2, itl_bars, itl_errors, "Time per output token (lower is better)", "ITL (seconds)")

            # One shared legend with the full scenario names. Both plots are
            # built from the same data in the same order, so their colours match.
            legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in ttft_bars["color"]]
            fig.legend(
                legend_handles, ttft_bars["label"], loc='lower center', ncol=1,
                bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                labelcolor='white', fontsize=12,
            )

            img_data = _figure_to_base64_png(fig)
        finally:
            # pyplot holds a reference to every open figure; always release it,
            # even when drawing or saving raises.
            plt.close(fig)

    # Return HTML with the embedded image at full height.
    html = f"""
<div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
</div>
"""
    return html