File size: 5,950 Bytes
55c8a69
 
 
 
 
dc41c89
 
55c8a69
 
e1f4b73
 
 
 
 
 
 
 
 
 
 
 
 
55c8a69
165f130
 
 
 
 
 
 
 
 
 
 
 
 
dc41c89
 
55c8a69
dc41c89
 
 
165f130
dc41c89
 
 
 
 
 
 
 
165f130
dc41c89
 
 
 
 
 
 
 
 
55c8a69
 
e1f4b73
55c8a69
dc41c89
55c8a69
 
165f130
 
 
 
 
 
 
55c8a69
165f130
dc41c89
55c8a69
dc41c89
165f130
dc41c89
e1f4b73
dc41c89
 
 
 
55c8a69
 
165f130
 
 
dc41c89
 
e1f4b73
dc41c89
e1f4b73
55c8a69
 
 
 
 
 
 
 
 
 
 
e1f4b73
55c8a69
e1f4b73
 
55c8a69
 
 
dc41c89
 
 
 
 
 
 
 
 
 
165f130
dc41c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import matplotlib.pyplot as plt
import io
import numpy as np
import base64

from plot_utils import get_color_for_config
from data import load_data


def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with entries sorted by their config.

    Entries are ordered by attention implementation (``flash_attention_2``,
    then ``sdpa``, then ``eager``), then by SDPA backend, then by the
    ``kernelize`` flag and finally by the ``compilation`` flag. Entries that
    compare equal keep their original relative order.

    Args:
        per_scenario_data: Mapping from scenario name to scenario data. Every
            value must hold a ``"config"`` mapping providing the keys
            ``"attn_implementation"``, ``"sdpa_backend"``, ``"kernelize"`` and
            ``"compilation"``.

    Returns:
        A new dict with the same keys and values, inserted in sorted order.

    Raises:
        KeyError: If a config lacks one of the keys above, or its attention
            implementation is not one of the three listed.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        attn_rank = attn_priority[cfg["attn_implementation"]]
        backend = cfg["sdpa_backend"]
        # None cannot be ordered against a backend name (TypeError), so a
        # leading flag ranks "no backend" first and the placeholder is only
        # ever compared with another placeholder.
        return (
            attn_rank,
            backend is not None,
            "" if backend is None else backend,
            cfg["kernelize"],
            cfg["compilation"],
        )

    return {k: per_scenario_data[k] for k in sorted(per_scenario_data, key=sorting_fn)}


def infer_bar_label(config: dict) -> str:
    """Build the human-readable legend label for one benchmark config.

    The label joins three parts with ``", "``: the display name of the
    attention implementation, whether compilation is on, and whether kernels
    are on. A ``KeyError`` is raised for an unknown attention implementation
    or a missing config key.
    """
    display_names = {
        "flash_attention_2": "Flash attention",
        "sdpa": "SDPA",
        "eager": "Eager",
    }
    attn_part = display_names[config["attn_implementation"]]

    if config["compilation"]:
        compile_part = "compiled"
    else:
        compile_part = "no compile"

    if config["kernelize"]:
        kernel_part = "kernelized"
    else:
        kernel_part = "no kernels"

    return ", ".join((attn_part, compile_part, kernel_part))


def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Collect the bar positions, heights, colors, labels and error bars for a metric.

    Devices are laid out left to right in iteration order. Within a device,
    scenarios are ordered by ``reorder_data`` and placed one unit apart; an
    extra gap of 1.5 units separates consecutive devices.

    Args:
        per_device_data: Mapping from device name to an object exposing
            ``get_bar_plot_data()``, which must return a mapping from scenario
            name to scenario data (each holding a ``"config"`` entry and the
            measurements stored under ``key``).
        key: Name of the metric to read from each scenario's data.

    Returns:
        A 3-tuple ``(bar_kwargs, errors_bars, x_ticks)`` where ``bar_kwargs``
        has the parallel lists ``"x"``, ``"height"`` (median of the metric),
        ``"color"`` and ``"label"``; ``errors_bars`` holds the standard
        deviation of the metric for each bar; and ``x_ticks`` is a list of
        ``(position, device_name)`` pairs centered on each device's bars.
    """
    # Prepare accumulators
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    errors_bars = []
    x_ticks = []

    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # A device without scenarios has no bars to center a tick on; the
            # mean of an empty list would be NaN, so leave the device out.
            continue

        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            measurements = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(measurements))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            errors_bars.append(np.std(measurements))
            device_xs.append(current_x)
            current_x += 1

        # Device name goes under the middle of its group of bars
        x_ticks.append((np.mean(device_xs), device_name))
        # Leave a visual gap before the next device's group
        current_x += 1.5
    return bar_kwargs, errors_bars, x_ticks

def _figure_to_html(fig) -> str:
    """Encode ``fig`` as a base64 PNG and wrap it in the page-filling HTML container."""
    # Save plot to bytes with high DPI for crisp text
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', facecolor='#000000',
                bbox_inches='tight', dpi=150)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full page coverage
    html = f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """
    return html


def create_matplotlib_bar_plot() -> str:
    """Create stacked matplotlib bar charts for TTFT and ITL data as embeddable HTML.

    The data comes from ``load_data()``. The figure has two vertically stacked
    axes sharing the x axis: time to first token on top and inter-token
    latency below, with one legend under the plots.

    If the devices do not all report the same main batch size, no bars are
    drawn; the figure only carries a title describing the mismatch.

    Returns:
        An HTML snippet embedding the rendered figure as a base64 PNG.
    """
    # Create figure with dark theme - maximum size for full screen
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
    # The figure is closed on every exit path (early return or exception) so
    # repeated calls do not accumulate open figures.
    try:
        fig.patch.set_facecolor('#000000')

        # Load and sanitize data
        per_device_data = load_data()
        batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
        if len(set(batch_sizes.values())) > 1:
            # suptitle forwards extra keywords to Text, which has no `pad`
            # property, so only valid text properties are passed here.
            fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18)
            return _figure_to_html(fig)

        # TTFT plot (top)
        ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)

        # ITL plot (bottom)
        itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)

        fig.tight_layout()

        # Build a common legend with one entry per distinct (label, color)
        # pair, in order of first appearance, whatever the number of devices.
        seen_entries = set()
        legend_labels, legend_colors = [], []
        for label, color in zip(ttft_bars["label"], ttft_bars["color"]):
            # str() keeps the key hashable even if the color is a list or array
            entry = (label, str(color))
            if entry in seen_entries:
                continue
            seen_entries.add(entry)
            legend_labels.append(label)
            legend_colors.append(color)
        legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_colors]

        # Put the legend below the plots, horizontally centered
        fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
                   bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=14)

        return _figure_to_html(fig)
    finally:
        plt.close(fig)


def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    title: "str | None",
    ylabel: str,
    xticks: list[tuple[float, str]],
) -> None:
    """Draw one bar chart with error bars on ``ax`` using the dark theme.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``; the ``"x"`` and
            ``"height"`` entries are also used to place the error bars.
        errors: Error bar half-lengths, parallel to ``bar_kwargs["height"]``.
        title: Title of the axes; ``None`` leaves the axes untitled.
        ylabel: Label of the y axis.
        xticks: ``(position, label)`` pairs for the x axis ticks.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.2, color='white', zorder=0)
    # Draw bars
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    # Add error bars
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )
    # Set labels and title
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.set_title(title, color='white', fontsize=18, pad=20)
    # Set ticks
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
    # Truncate axis to better fit the bars: keep a 2% margin around the
    # extent covered by the bars and their error bars
    new_ymin, new_ymax = 1e9, -1e9
    for h, e in zip(bar_kwargs["height"], errors):
        new_ymin = min(new_ymin, 0.98 * (h - e))
        new_ymax = max(new_ymax, 1.02 * (h + e))
    # With no bars (or no usable values) the sentinels are still in place and
    # would produce inverted limits, so keep the automatic limits in that case
    if new_ymin < new_ymax:
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))