Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import io | |
| import numpy as np | |
| import base64 | |
| # Color manipulation functions | |
def hex_to_rgb(hex_color):
    """Convert a hex color string into an ``(r, g, b)`` tuple of ints.

    Leading ``#`` characters are stripped and only the first six hex digits
    are read, so any trailing digits (e.g. the alpha pair of ``#RRGGBBAA``)
    are ignored.
    """
    digits = hex_color.lstrip('#')
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return red, green, blue
def blend_colors(rgb, hex_color, blend_strength):
    """Linearly mix ``rgb`` with the color given by ``hex_color``.

    ``blend_strength`` is the weight given to ``rgb``; the parsed hex color
    receives the complementary weight ``1 - blend_strength``. Each mixed
    channel is truncated with ``int`` and the result is a 3-tuple.
    """
    other_rgb = hex_to_rgb(hex_color)
    other_weight = 1 - blend_strength
    mixed = []
    for channel in range(3):
        mixed.append(int(rgb[channel] * blend_strength + other_rgb[channel] * other_weight))
    return tuple(mixed)
def increase_brightness(r, g, b, factor):
    """Move each channel toward 255 by ``factor`` of its remaining headroom.

    Returns a 3-tuple of ints (values are truncated, not rounded).
    """
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))
def decrease_brightness(r, g, b, factor):
    """Scale every channel by ``factor`` and truncate to int.

    Returns a 3-tuple; a ``factor`` below 1 darkens the color.
    """
    return tuple(int(channel * factor) for channel in (r, g, b))
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the color's gray level by ``factor``.

    The gray level is the weighted sum ``0.299*r + 0.587*g + 0.114*b``.
    Results are truncated to int and are NOT clamped to 0..255.
    """
    gray = 0.299 * r + 0.587 * g + 0.114 * b
    return tuple(int(gray + (channel - gray) * factor) for channel in (r, g, b))
def rgb_to_hex(r, g, b):
    """Format integer channels as a lowercase ``#rrggbb`` string.

    Each channel is clamped into the 0..255 range before formatting.
    """
    clamped = [min(max(channel, 0), 255) for channel in (r, g, b)]
    return "#" + "".join(f"{channel:02x}" for channel in clamped)
| # Color assignment function | |
def get_color_for_config(config, filtered_on_compile_mode: bool = False):
    """Pick the bar color (``#rrggbb``) for one benchmark configuration.

    The base hue depends on ``config["attn_implementation"]`` (and, for
    ``"sdpa"``, on ``config["sdpa_backend"]``); each hue has a compiled and a
    non-compiled variant. Compiled configs are then brightened and kernelized
    configs darkened.

    Raises:
        ValueError: if the attention implementation is not recognised.
        KeyError: if a required config key is missing or the sdpa backend
            has no hue assigned.
    """
    attn_implementation = config["attn_implementation"]
    sdpa_backend = config["sdpa_backend"]
    compilation = config["compilation"]

    # Each entry is (hue when compiled, hue when not compiled).
    if attn_implementation == "eager":
        hue_pair = ("#FF4B4BFF", "#FF4141FF")
    elif attn_implementation == "sdpa":
        sdpa_hues = {
            None: ("#4A90E2", "#2E82E1FF"),
            "math": ("#408DDB", "#227BD3FF"),
            "flash_attention": ("#35A34D", "#219F3CFF"),
            "efficient_attention": ("#605895", "#423691FF"),
            "cudnn_attention": ("#774AE2", "#5D27DCFF"),
        }
        hue_pair = sdpa_hues[sdpa_backend]
    elif attn_implementation == "flash_attention_2":
        hue_pair = ("#FFD700", "#FFBF00FF")
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")
    main_hue = hue_pair[0] if compilation else hue_pair[1]

    r, g, b = hex_to_rgb(main_hue)

    if config["compilation"]:
        delta = 0.2
        if filtered_on_compile_mode:
            # The length of the compile-mode name is used to spread the
            # brightness between modes (a 7-character name adds nothing).
            delta += 0.2 * (len(config["compile_mode"]) - 7) / 8
        else:
            delta += 0
        r, g, b = increase_brightness(r, g, b, delta)

    # NOTE(review): the original indentation was lost; the darkening step is
    # read as belonging to the kernelize branch - confirm against upstream.
    # Alternative tweaks (blending toward "#FF00F2FF", a saturation change)
    # are currently disabled.
    if config["kernelize"]:
        r, g, b = decrease_brightness(r, g, b, 0.8)

    return rgb_to_hex(r, g, b)
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a new dict with the scenarios sorted for plotting.

    Scenarios are ordered by attention implementation (flash_attention_2,
    then sdpa, then eager), then by sdpa backend, kernelize flag and
    compilation flag. The sort is stable, so scenarios with identical
    configs keep their input order.

    A ``None`` sdpa backend sorts before any named backend; comparing
    ``None`` with a string directly would raise ``TypeError``.

    Raises:
        KeyError: if a scenario has an unknown attention implementation or
            lacks one of the config keys used for sorting.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        backend = cfg["sdpa_backend"]
        return (
            attn_priority[cfg["attn_implementation"]],
            # Split None from named backends so the two are never compared.
            backend is not None,
            "" if backend is None else backend,
            cfg["kernelize"],
            cfg["compilation"],
        )

    return {k: per_scenario_data[k] for k in sorted(per_scenario_data, key=sorting_fn)}
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Build the bar-chart arguments and error values for one metric.

    For every scenario (in dict order) the bar height is the median of
    ``data[key]`` and the error is its standard deviation; the color comes
    from the scenario's config and the label is the scenario name.

    Returns:
        ``(bar_kwargs, errors)`` where ``bar_kwargs`` has the keys ``x``,
        ``height``, ``color`` and ``label`` and ``errors`` is a parallel list.
    """
    positions, medians, colors, names, errors = [], [], [], [], []
    for position, (scenario, scenario_data) in enumerate(per_scenario_data.items()):
        samples = scenario_data[key]
        positions.append(position)
        medians.append(np.median(samples))
        colors.append(get_color_for_config(scenario_data["config"]))
        names.append(scenario)
        errors.append(np.std(samples))

    bar_kwargs = {"x": positions, "height": medians, "color": colors, "label": names}
    return bar_kwargs, errors
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
    """Draw one dark-themed bar chart with error bars on ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``; its ``"x"``
            and ``"height"`` entries are also used for the error bars.
        errors: Error magnitude per bar, parallel to ``bar_kwargs["height"]``.
        title: Title of the plot.
        ylabel: Label of the y axis.

    The y axis is tightened around the bars plus their error bars (with a 2%
    margin) but never extended beyond the limits matplotlib chose. When
    there are no bars the limits are left untouched.
    """
    ax.set_facecolor('#000000')

    # Bars, then error bars on top of them.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
    )

    # Labels and title.
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.set_title(title, color='white', fontsize=18, pad=20)

    # The x positions carry no meaning, so hide the x ticks.
    ax.set_xticks([])
    ax.tick_params(colors='white', labelsize=13)

    # Truncate the y axis to better fit the bars.
    spans = [(h - e, h + e) for h, e in zip(bar_kwargs["height"], errors)]
    if not spans:
        # Nothing to fit: keep matplotlib's limits instead of inverting the axis.
        return
    new_ymin = min(0.98 * low for low, _ in spans)
    new_ymax = max(1.02 * high for _, high in spans)
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Create side-by-side bar charts for TTFT, ITL and E2E latency.

    Args:
        per_scenario_data: Mapping of scenario name to a dict holding the
            scenario ``"config"`` and the ``"ttft"``, ``"itl"`` and ``"e2e"``
            measurements.

    Returns:
        An HTML snippet embedding the rendered figure as a base64 PNG.

    Note:
        Switches the global matplotlib style to ``dark_background``.
    """
    # (metric key, plot title, y-axis label), drawn left to right.
    panels = (
        ("ttft", "Time to first token (lower is better)", "TTFT (seconds)"),
        ("itl", "Inter token latency (lower is better)", "ITL (seconds)"),
        ("e2e", "End-to-end latency (lower is better)", "E2E (seconds)"),
    )

    # Create figure with dark theme - maximum size for full screen
    plt.style.use('dark_background')
    fig, axs = plt.subplots(1, 3, figsize=(30, 16))
    try:
        fig.patch.set_facecolor('#000000')

        # Reorganize data
        per_scenario_data = reorder_data(per_scenario_data)

        legend_bars = None
        for ax, (metric, title, ylabel) in zip(axs, panels):
            bars, errors = make_bar_kwargs(per_scenario_data, metric)
            draw_bar_plot(ax, bars, errors, title, ylabel)
            if legend_bars is None:
                # Labels and colors are the same for every metric; reuse the first.
                legend_bars = bars

        # Add common legend with full (untruncated) labels
        legend_labels = legend_bars["label"]
        legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_bars["color"]]
        fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
                   bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=14)

        # Save plot to bytes with high DPI for crisp text. Saving through the
        # figure itself avoids depending on which figure is "current".
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even when building the plot fails.
        plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full page coverage
    html = f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """
    return html