Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import io | |
| import numpy as np | |
| import base64 | |
| from plot_utils import get_color_for_config | |
| from data import load_data, ModelBenchmarkData | |
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with scenarios in a stable display order.

    Scenarios are ordered by attention implementation, then SDPA backend, then
    whether kernels are enabled, then whether the model is compiled. Attention
    implementations or SDPA backends missing from the priority tables are placed
    after all known ones instead of raising ``KeyError``. Ties keep their original
    relative order because ``sorted`` is stable.

    Args:
        per_scenario_data: Mapping of scenario name to scenario data. Each value
            must hold a ``"config"`` dict with the keys ``attn_implementation``,
            ``sdpa_backend``, ``kernelize`` and ``compile_mode``.

    Returns:
        A new dict containing the same items, reordered.
    """
    attn_impl_prio = {
        "flash_attention_2": 0,
        "sdpa": 1,
        "eager": 2,
        "flex_attention": 3,
    }
    sdpa_backend_prio = {
        None: -1,
        "flash_attention": 0,
        "math": 1,
        "efficient_attention": 2,
        "cudnn_attention": 3,
    }

    def sorting_fn(item: tuple) -> tuple:
        cfg = item[1]["config"]
        return (
            # Unknown values get a priority past every known one, so they go last
            attn_impl_prio.get(cfg["attn_implementation"], len(attn_impl_prio)),
            sdpa_backend_prio.get(cfg["sdpa_backend"], len(sdpa_backend_prio)),
            cfg["kernelize"],
            cfg["compile_mode"] is not None,
        )

    return dict(sorted(per_scenario_data.items(), key=sorting_fn))
def infer_bar_label(config: dict) -> str:
    """Format a human-readable legend label for a benchmark configuration.

    The label has the form ``"<attention>, <compile>, <kernels>"``, for example
    ``"SDPA (math), compiled, no kernels"``.

    Args:
        config: Scenario config with the keys ``attn_implementation``,
            ``compile_mode`` and ``kernelize``. ``sdpa_backend`` is only read
            when the implementation is ``"sdpa"``.

    Returns:
        The formatted label. An unrecognised attention implementation is
        rendered as ``"Unknown"``, and an unrecognised SDPA backend as
        ``"SDPA (unknown backend)"``.
    """
    attn_impl = config["attn_implementation"]
    if attn_impl == "sdpa":
        attn_label = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        attn_label = {
            "eager": "Eager",
            "flash_attention_2": "Flash attention",
            "flex_attention": "Flex attention",
        }.get(attn_impl, "Unknown")
    # Named compile_label (not `compile`) to avoid shadowing the builtin
    compile_label = "compiled" if config["compile_mode"] is not None else "no compile"
    kernels_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernels_label}"
def infer_bar_hatch(config: dict) -> str:
    """Pick the bar hatch pattern: ``"/"`` for compiled configs, ``""`` otherwise."""
    is_compiled = config["compile_mode"] is not None
    return "/" if is_compiled else ""
def make_bar_kwargs(
    per_device_data: dict[str, ModelBenchmarkData], key: str
) -> tuple[dict, list, list]:
    """Build ``Axes.bar`` keyword arguments, error bars and x ticks for one metric.

    Bars are laid out left to right, one per scenario, grouped by device.
    Consecutive bars of a device are 1 unit apart. After each device group,
    an extra 1.5 units of space are added.

    Args:
        per_device_data: Mapping of device name to that device's benchmark data.
        key: Metric looked up in each scenario's data (e.g. ``"ttft"`` or
            ``"itl"``).

    Returns:
        A ``(bar_kwargs, error_bars, x_ticks)`` tuple:

        - ``bar_kwargs`` holds the lists ``x``, ``height`` (median of the
          metric), ``color``, ``label`` and ``hatch``.
        - ``error_bars`` holds the standard deviation of the metric for each bar.
        - ``x_ticks`` holds one ``(center_x, device_name)`` pair per device
          that has at least one scenario.
    """
    bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
    error_bars = []
    x_ticks = []
    current_x = 0
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # No bars for this device: skip it rather than placing its tick at
            # np.mean([]), which is NaN (with a RuntimeWarning)
            continue
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            values = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(values))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            bar_kwargs["hatch"].append(infer_bar_hatch(config))
            error_bars.append(np.std(values))
            device_xs.append(current_x)
            current_x += 1
        # Center the device's tick under its group of bars
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5
    return bar_kwargs, error_bars, x_ticks
def _build_legend_handles(bar_kwargs: dict) -> list:
    """Return one legend patch per distinct bar label, in first-appearance order."""
    first_style_per_label = {}
    for label, color, hatch in zip(
        bar_kwargs["label"], bar_kwargs["color"], bar_kwargs["hatch"]
    ):
        first_style_per_label.setdefault(label, (color, hatch))
    return [
        mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
        for label, (color, hatch) in first_style_per_label.items()
    ]


def _figure_to_html(fig) -> str:
    """Render ``fig`` as a base64 PNG embedded in an HTML ``<div>``, then close it."""
    buffer = io.BytesIO()
    try:
        # High DPI for crisp text
        fig.savefig(
            buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150
        )
    finally:
        # Always release the figure, even if saving fails
        plt.close(fig)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    return f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """


def create_matplotlib_bar_plot() -> str:
    """Render TTFT and ITL bar charts for all benchmarked devices as embeddable HTML.

    Loads the benchmark data and checks that every device used the same batch
    size, sequence length and number of generated tokens. It then draws two
    vertically stacked bar charts, TTFT on top and ITL below, sharing one legend.

    Returns:
        An HTML ``<div>`` containing the figure as a base64-encoded PNG. If the
        devices' generation parameters disagree, the figure contains only a
        title describing the mismatch.
    """
    # NOTE: plt.style.use changes the global matplotlib style, not just this figure
    plt.style.use("dark_background")
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)
    fig.patch.set_facecolor("#000000")

    # Load data and make sure all devices were benchmarked with the same parameters
    per_device_data = load_data()
    batch_size, sequence_length, num_tokens_to_generate = None, None, None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if batch_size is None:
            batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
        elif (bs, seqlen, n_tok) != (
            batch_size,
            sequence_length,
            num_tokens_to_generate,
        ):
            fig.suptitle(
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                color="white",
                fontsize=18,
            )
            # Render (and close) the figure so the mismatch message is actually shown
            return _figure_to_html(fig)

    # TTFT plot (top)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
    # ITL plot (bottom)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

    # Title and tight layout
    title = "\n".join(
        [
            "Time to first token and inter-token latency (lower is better)",
            f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
        ]
    )
    fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
    fig.tight_layout()

    # Shared legend centered below both plots. Labels, colors and hatches depend
    # only on each scenario's config, so the TTFT bars describe the ITL bars too.
    fig.legend(
        handles=_build_legend_handles(ttft_bars),
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.515, -0.11),
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        fontsize=14,
    )
    return _figure_to_html(fig)
def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw a dark-themed bar chart with symmetric error bars on ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``. Must contain the
            lists ``x`` and ``height`` (one entry per bar); other per-bar
            ``Axes.bar`` arguments such as ``color``, ``label`` or ``hatch``
            may also be present.
        errors: Error-bar half-length for each bar, in the same order as
            ``bar_kwargs["x"]``.
        ylabel: Label of the y axis.
        xticks: ``(position, label)`` pairs used as x-axis ticks.
        adapt_ylim: If true, tighten the y limits around the bars and their
            error bars, with a 2% margin. The current limits are never widened.
            Ignored when there are no bars.
    """
    ax.set_facecolor("#000000")
    ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)

    # Bars above the grid (zorder 3), error bars above the bars (zorder 4)
    ax.bar(**bar_kwargs, width=1.0, edgecolor="white", linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"],
        bar_kwargs["height"],
        yerr=errors,
        fmt="none",
        ecolor="white",
        alpha=0.8,
        elinewidth=1.5,
        capthick=1.5,
        capsize=4,
        zorder=4,
    )

    # Labels and ticks
    ax.set_ylabel(ylabel, color="white", fontsize=16)
    ax.tick_params(colors="white", labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)

    # Truncate the y axis to better fit the bars
    height_error_pairs = list(zip(bar_kwargs["height"], errors))
    if adapt_ylim and height_error_pairs:
        low = min(h - e for h, e in height_error_pairs)
        high = max(h + e for h, e in height_error_pairs)
        # Additive 2% margin, so it widens outward even for negative values
        new_ymin = low - 0.02 * abs(low)
        new_ymax = high + 0.02 * abs(high)
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))