import matplotlib.pyplot as plt import io import numpy as np import base64 from plot_utils import get_color_for_config from data import load_data def reorder_data(per_scenario_data: dict) -> dict: keys = list(per_scenario_data.keys()) def sorting_fn(key: str) -> float: cfg = per_scenario_data[key]["config"] attn_implementation = cfg["attn_implementation"] attn_implementation_prio = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}[attn_implementation] return attn_implementation_prio, cfg["sdpa_backend"], cfg["kernelize"], cfg["compilation"] keys.sort(key=sorting_fn) per_scenario_data = {k: per_scenario_data[k] for k in keys} return per_scenario_data def infer_bar_label(config: dict) -> str: """Format legend labels to be more readable.""" if config["attn_implementation"] == "eager": attn_implementation = "Eager" elif config["attn_implementation"] == "flash_attention_2": attn_implementation = "Flash attention" elif config["attn_implementation"] == "sdpa": attn_implementation = { "flash_attention": "SDPA (flash attention)", "efficient_attention": "SDPA (efficient_attention)", "cudnn_attention": "SDPA (cudnn)", "math": "SDPA (math)", }.get(config["sdpa_backend"], "SDPA (unknown backend)") else: attn_implementation = "Unknown" compile = "compiled" if config["compilation"] else "no compile" kernels = "kernelized" if config["kernelize"] else "no kernels" return f"{attn_implementation}, {compile}, {kernels}" def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]: # Prepare accumulators current_x = 0 bar_kwargs = {"x": [], "height": [], "color": [], "label": []} errors_bars = [] x_ticks = [] for device_name, device_data in per_device_data.items(): per_scenario_data = device_data.get_bar_plot_data() per_scenario_data = reorder_data(per_scenario_data) device_xs = [] for scenario_name, scenario_data in per_scenario_data.items(): bar_kwargs["x"].append(current_x) bar_kwargs["height"].append(np.median(scenario_data[key])) bar_kwargs["color"].append(get_color_for_config(scenario_data["config"])) bar_kwargs["label"].append(infer_bar_label(scenario_data["config"])) errors_bars.append(np.std(scenario_data[key])) device_xs.append(current_x) current_x += 1 x_ticks.append((np.mean(device_xs), device_name)) current_x += 1.5 return bar_kwargs, errors_bars, x_ticks def create_matplotlib_bar_plot() -> None: """Create side-by-side matplotlib bar charts for TTFT and TPOT data.""" # Create figure with dark theme - maximum size for full screen plt.style.use('dark_background') fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True) fig.patch.set_facecolor('#000000') # Load data and ensure coherence per_device_data = load_data() batch_size, sequence_length, num_tokens_to_generate = None, None, None for device_name, device_data in per_device_data.items(): bs, seqlen, n_tok = device_data.ensure_coherence() if batch_size is None: batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate): fig.suptitle( f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} " f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}", color='white', fontsize=18, pad=20 ) return None # TTFT Plot (left) ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft") draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks) # # ITL Plot (right) itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl") draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks) # Title and tight layout title = "\n".join([ "Time to first token and inter-token latency (lower is better)", f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}", ]) fig.suptitle(title, color='white', fontsize=20, y=1.005) plt.tight_layout() # Add common legend with full text unique_bars = len(ttft_bars["label"]) // 2 legend_labels, legend_colors = ttft_bars["label"][:unique_bars], ttft_bars["color"][:unique_bars] legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in legend_colors] # Put a legend to the right of the current axis fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4, bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white', labelcolor='white', fontsize=14) # Save plot to bytes with high DPI for crisp text buffer = io.BytesIO() plt.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150) buffer.seek(0) # Convert to base64 for HTML embedding img_data = base64.b64encode(buffer.getvalue()).decode() plt.close(fig) # Return HTML with embedded image - full page coverage html = f"""