Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import io | |
| import numpy as np | |
| import base64 | |
| from plot_utils import get_color_for_config | |
| from data import load_data | |
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with scenarios in display order.

    Scenarios are ordered by attention implementation (``flash_attention_2``,
    then ``sdpa``, then ``eager``, then any unrecognized implementation), then
    by SDPA backend name, then by the ``kernelize`` and ``compilation`` values.
    The sort is stable, so ties keep their input order.

    Args:
        per_scenario_data: Mapping of scenario name to a dict holding at least a
            ``"config"`` dict with the keys ``attn_implementation``,
            ``sdpa_backend``, ``kernelize`` and ``compilation``.

    Returns:
        A new dict with the same items, inserted in sorted order.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        # Unrecognized implementations go last instead of raising KeyError;
        # infer_bar_label labels them "Unknown".
        prio = attn_priority.get(cfg["attn_implementation"], len(attn_priority))
        # sdpa_backend may be None (presumably for non-SDPA configs); comparing
        # None with a backend name would raise TypeError, so map it to "".
        backend = cfg["sdpa_backend"] or ""
        return prio, backend, cfg["kernelize"], cfg["compilation"]

    ordered_keys = sorted(per_scenario_data, key=sorting_fn)
    return {k: per_scenario_data[k] for k in ordered_keys}
def infer_bar_label(config: dict) -> str:
    """Format a readable legend label for one benchmark configuration.

    Args:
        config: Scenario config with the keys ``attn_implementation``,
            ``compilation`` and ``kernelize``; ``sdpa_backend`` is read only
            when the implementation is ``"sdpa"``.

    Returns:
        A label of the form ``"<attention>, <compile state>, <kernel state>"``,
        e.g. ``"SDPA (math), compiled, no kernels"``.
    """
    attn_name = config["attn_implementation"]
    if attn_name == "sdpa":
        attn_label = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        attn_label = {
            "eager": "Eager",
            "flash_attention_2": "Flash attention",
        }.get(attn_name, "Unknown")
    # Named compile_label so the builtin ``compile`` is not shadowed.
    compile_label = "compiled" if config["compilation"] else "no compile"
    kernels_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernels_label}"
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Flatten per-device benchmark results into arguments for ``Axes.bar``.

    Bars are laid out left to right, one unit apart, one bar per scenario.
    Consecutive devices are separated by an extra gap of 1.5 units. Within a
    device, scenarios follow the order produced by ``reorder_data``.

    Args:
        per_device_data: Mapping of device name to an object exposing
            ``get_bar_plot_data()``. That method returns a mapping of scenario
            name to a dict holding the samples under ``key`` and a ``"config"``.
        key: Name of the metric to plot (e.g. ``"ttft"`` or ``"itl"``).

    Returns:
        A tuple ``(bar_kwargs, error_bars, x_ticks)``:
        - ``bar_kwargs`` has list-valued keys ``x``, ``height`` (sample median),
          ``color`` and ``label``.
        - ``error_bars`` holds the standard deviation of each bar's samples.
        - ``x_ticks`` holds one ``(center_x, device_name)`` pair per device
          that has at least one scenario.
    """
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    error_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        device_xs = []
        for scenario_data in per_scenario_data.values():
            samples = scenario_data[key]
            config = scenario_data["config"]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            error_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1
        # A device with no scenarios would otherwise get a NaN tick position
        # (np.mean of an empty list) plus a useless gap; skip it entirely.
        if not device_xs:
            continue
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5
    return bar_kwargs, error_bars, x_ticks
def _common_benchmark_shape(per_device_data: dict) -> tuple:
    """Return ``((batch_size, sequence_length, num_tokens), error_message)`` shared by all devices.

    ``error_message`` is None when every device agrees. When no devices are
    present, the shape is ``(None, None, None)``.
    """
    reference = None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        current = (bs, seqlen, n_tok)
        if reference is None:
            reference = current
        elif current != reference:
            batch_size, sequence_length, num_tokens_to_generate = reference
            message = (
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}"
            )
            return reference, message
    return (reference if reference is not None else (None, None, None)), None


def _unique_legend_entries(bar_kwargs: dict) -> tuple[list, list]:
    """Return the distinct bar labels (first-seen order) and the color of each."""
    color_by_label = {}
    for label, color in zip(bar_kwargs["label"], bar_kwargs["color"]):
        color_by_label.setdefault(label, color)
    return list(color_by_label), list(color_by_label.values())


def _figure_to_html(fig) -> str:
    """Render ``fig`` to a 150-DPI PNG and embed it, base64-encoded, in an HTML div."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    html = f"""
<div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
</div>
"""
    return html


def create_matplotlib_bar_plot() -> str:
    """Render the TTFT and ITL bar charts and return them as embeddable HTML.

    Two vertically stacked bar charts share the x axis: time to first token on
    top and inter-token latency below, with one group of bars per device. The
    figure is rendered to PNG and inlined as base64 in an HTML ``<img>``.

    If the devices disagree on batch size, sequence length or number of
    generated tokens, the returned figure shows only the mismatch message.

    Returns:
        An HTML snippet embedding the rendered PNG.
    """
    # Dark theme; large figure so the page can show it full screen.
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
    fig.patch.set_facecolor('#000000')
    try:
        per_device_data = load_data()
        shape, mismatch_message = _common_benchmark_shape(per_device_data)
        if mismatch_message is not None:
            # Show the problem in place of the charts. Figure.suptitle only
            # accepts Text properties, so no ``pad`` argument here.
            for ax in axs:
                ax.set_axis_off()
            fig.suptitle(mismatch_message, color='white', fontsize=18)
            return _figure_to_html(fig)
        batch_size, sequence_length, num_tokens_to_generate = shape

        # TTFT plot (top)
        ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
        # ITL plot (bottom)
        itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

        title = "\n".join([
            "Time to first token and inter-token latency (lower is better)",
            f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
        ])
        fig.suptitle(title, color='white', fontsize=20, y=1.005)
        fig.tight_layout()

        # One legend entry per distinct configuration, whatever the number of
        # devices; the same config has the same label/color on every device.
        legend_labels, legend_colors = _unique_legend_entries(ttft_bars)
        legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_colors]
        # Common legend centered below both plots.
        fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
                   bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=14)
        return _figure_to_html(fig)
    finally:
        # Always release the figure, including when loading or drawing fails.
        plt.close(fig)
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, ylabel: str, xticks: list[tuple[float, str]]) -> None:
    """Draw one dark-themed bar chart with error bars onto ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments for ``Axes.bar`` as built by
            ``make_bar_kwargs`` (list-valued ``x``, ``height``, ``color``,
            ``label``).
        errors: Symmetric error-bar size for each bar, in the same order as
            ``bar_kwargs["height"]``.
        ylabel: Label for the y axis.
        xticks: ``(position, label)`` pairs, one per device group.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.2, color='white', zorder=0)
    # zorder 3 (bars) and 4 (error bars) keep both drawn above the grid.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([pos for pos, _ in xticks], [name for _, name in xticks], fontsize=16)

    # Zoom the y axis onto the error-bar extents (with 2% padding), never
    # widening beyond matplotlib's autoscaled limits. Non-finite values
    # (e.g. the median of an empty sample) are ignored.
    extents = [
        (h - e, h + e)
        for h, e in zip(bar_kwargs["height"], errors)
        if np.isfinite(h) and np.isfinite(e)
    ]
    if not extents:
        # Nothing to fit: keep the autoscaled limits rather than inverting them.
        return
    new_ymin = min(0.98 * low for low, _ in extents)
    new_ymax = max(1.02 * high for _, high in extents)
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))