File size: 6,968 Bytes
55c8a69
 
 
 
 
dc41c89
 
55c8a69
 
e1f4b73
 
 
 
 
 
 
 
 
 
 
 
 
55c8a69
165f130
 
9f6e83d
 
 
 
 
 
 
 
 
 
 
 
 
 
165f130
 
 
 
 
 
dc41c89
 
55c8a69
dc41c89
 
 
165f130
dc41c89
 
 
 
 
 
 
 
165f130
dc41c89
 
 
 
 
 
 
 
 
55c8a69
 
e1f4b73
55c8a69
c4f3c79
55c8a69
 
22cf82d
165f130
22cf82d
 
 
 
 
 
 
 
 
aa7e786
22cf82d
 
165f130
55c8a69
165f130
22cf82d
55c8a69
dc41c89
165f130
22cf82d
 
 
 
 
 
 
c4f3c79
dc41c89
55c8a69
 
165f130
 
 
dc41c89
 
e1f4b73
22cf82d
e1f4b73
55c8a69
 
 
 
 
 
 
 
 
 
 
e1f4b73
55c8a69
e1f4b73
 
55c8a69
 
 
dc41c89
 
c4f3c79
 
 
 
 
 
 
 
dc41c89
c4f3c79
dc41c89
 
 
 
 
165f130
dc41c89
22cf82d
dc41c89
 
 
 
 
c4f3c79
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import matplotlib.pyplot as plt
import io
import numpy as np
import base64

from plot_utils import get_color_for_config
from data import load_data


# Display order of attention implementations; unlisted implementations sort after these.
_ATTN_IMPLEMENTATION_PRIORITY = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}


def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with scenarios in a stable display order.

    Scenarios are sorted by attention implementation (flash attention 2, then SDPA,
    then eager, then any unrecognised implementation). Ties are broken by the SDPA
    backend, then kernelization, then compilation, all read from each scenario's
    ``"config"`` entry. The sort is stable.

    Args:
        per_scenario_data: Mapping of scenario name to scenario data. Every value must
            contain a ``"config"`` dict with the keys ``attn_implementation``,
            ``sdpa_backend``, ``kernelize`` and ``compilation``.

    Returns:
        A new dict holding the same items, inserted in sorted order.
    """

    def _none_last(value) -> tuple:
        # ``None`` cannot be compared with ``<`` against str/bool values. Wrapping each
        # field makes missing settings sort after present ones instead of raising
        # TypeError; non-None values keep their natural order.
        return (value is None, value)

    def _sort_key(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        attn_prio = _ATTN_IMPLEMENTATION_PRIORITY.get(
            cfg["attn_implementation"], len(_ATTN_IMPLEMENTATION_PRIORITY)
        )
        return (
            attn_prio,
            _none_last(cfg["sdpa_backend"]),
            _none_last(cfg["kernelize"]),
            _none_last(cfg["compilation"]),
        )

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=_sort_key)}


# Human-readable names for non-SDPA attention implementations.
_ATTN_IMPLEMENTATION_LABELS = {"eager": "Eager", "flash_attention_2": "Flash attention"}
# Human-readable names for the SDPA backends, keyed by ``config["sdpa_backend"]``.
_SDPA_BACKEND_LABELS = {
    "flash_attention": "SDPA (flash attention)",
    "efficient_attention": "SDPA (efficient_attention)",
    "cudnn_attention": "SDPA (cudnn)",
    "math": "SDPA (math)",
}


def infer_bar_label(config: dict) -> str:
    """Format a readable legend label for a benchmark configuration.

    Args:
        config: Scenario config with ``attn_implementation``, ``compilation`` and
            ``kernelize`` keys. ``sdpa_backend`` is only read when the
            implementation is ``"sdpa"``.

    Returns:
        A label of the form ``"<attention>, <compiled|no compile>, <kernelized|no kernels>"``.
        Unrecognised implementations or SDPA backends get an "Unknown" /
        "SDPA (unknown backend)" attention part instead of raising.
    """
    attn_impl = config["attn_implementation"]
    if attn_impl == "sdpa":
        attn_label = _SDPA_BACKEND_LABELS.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        attn_label = _ATTN_IMPLEMENTATION_LABELS.get(attn_impl, "Unknown")

    # Use suffixed names so the builtin ``compile`` is not shadowed.
    compile_label = "compiled" if config["compilation"] else "no compile"
    kernels_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernels_label}"


def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Flatten per-device benchmark results into arguments for ``Axes.bar``.

    Bars are laid out left to right, one per scenario, and grouped by device with a
    1.5-unit gap between groups. Scenario order within a device comes from
    ``reorder_data``. Devices with no scenarios are skipped.

    Args:
        per_device_data: Mapping of device name to a device-data object exposing
            ``get_bar_plot_data()``. That method returns a mapping of scenario name to
            a dict holding ``"config"`` and the metric samples under ``key``.
        key: Name of the metric to plot (``"ttft"`` and ``"itl"`` are used by this module).

    Returns:
        A ``(bar_kwargs, errors_bars, x_ticks)`` tuple:

        - ``bar_kwargs`` holds the ``x``/``height``/``color``/``label`` lists for
          ``Axes.bar``; each height is the median of the metric samples.
        - ``errors_bars`` holds the matching standard deviations.
        - ``x_ticks`` holds one ``(center_x, device_name)`` pair per plotted device.
    """
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    errors_bars = []
    x_ticks = []
    current_x = 0

    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # Nothing to draw; skipping avoids a tick at np.mean([]) == nan (plus a warning).
            continue

        device_xs = []
        for scenario_data in per_scenario_data.values():
            samples = scenario_data[key]
            config = scenario_data["config"]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            errors_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1

        # Center the device name under its group of bars.
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5  # gap separating this device's group from the next one
    return bar_kwargs, errors_bars, x_ticks

def _figure_to_html(fig) -> str:
    """Render ``fig`` as a base64 PNG embedded in a full-page HTML ``<div>``; always closes ``fig``."""
    buffer = io.BytesIO()
    try:
        # High DPI keeps the text crisp once the image is scaled to the viewport.
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=150)
    finally:
        # Close even if saving fails, so pyplot does not keep the figure alive.
        plt.close(fig)

    # Base64 lets the PNG be inlined directly in the HTML.
    img_data = base64.b64encode(buffer.getvalue()).decode()
    html = f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """
    return html


def create_matplotlib_bar_plot() -> str:
    """Build the TTFT / ITL latency bar charts and return them as embeddable HTML.

    The figure has two vertically stacked panels sharing the x-axis: time to first
    token (top) and inter-token latency (bottom), with one bar group per device.

    Returns:
        An HTML snippet embedding the rendered PNG. If the devices' data disagree on
        batch size, sequence length or number of generated tokens, the snippet shows
        a figure whose title describes the mismatch.
    """
    # Dark theme; figure sized to fill most of the screen.
    plt.style.use('dark_background')
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)  # used to be 30, 16
    fig.patch.set_facecolor('#000000')

    # Load data and make sure every device was benchmarked with the same settings.
    per_device_data = load_data()
    batch_size, sequence_length, num_tokens_to_generate = None, None, None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if batch_size is None:
            batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
        elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
            fig.suptitle(
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                color='white', fontsize=18
            )
            # Render the message (hiding the empty axes) so it is actually visible,
            # and so the figure is closed instead of leaking.
            for ax in axs:
                ax.set_axis_off()
            return _figure_to_html(fig)

    # TTFT plot (top panel)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)

    # ITL plot (bottom panel)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

    # Title and tight layout
    title = "\n".join([
        "Time to first token and inter-token latency (lower is better)",
        f"Batch size: {batch_size},  sequence length: {sequence_length},  new tokens: {num_tokens_to_generate}",
    ])
    fig.suptitle(title, color='white', fontsize=20, y=1.005, linespacing=1.5)
    fig.tight_layout()

    # One legend entry per distinct label, in first-seen order. This stays correct for
    # any number of devices, unlike taking the first half of all bars.
    legend_colors_by_label = {}
    for label, color in zip(ttft_bars["label"], ttft_bars["color"]):
        legend_colors_by_label.setdefault(label, color)
    legend_labels = list(legend_colors_by_label)
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_colors_by_label.values()]

    # Shared legend, centered below both panels.
    fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
               bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
               labelcolor='white', fontsize=14)

    return _figure_to_html(fig)


def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw a dark-themed bar chart with symmetric error bars on ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``Axes.bar``. Must include ``"x"``
            and ``"height"``, which also position the error bars.
        errors: Error-bar half-length for each bar, aligned with ``bar_kwargs["height"]``.
        ylabel: Y-axis label.
        xticks: ``(position, label)`` pairs used as the x-axis ticks.
        adapt_ylim: If true, tighten the y-limits to the bars plus their error bars,
            padded by 2%. The limits are never widened beyond the current ones.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.3, color='white', axis='y', zorder=0)
    # z-order: grid (0) < bars (3) < error bars (4).
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )
    # Labels and ticks
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)

    # Truncate the y-axis to better fit the bars.
    if adapt_ylim:
        lows = [h - e for h, e in zip(bar_kwargs["height"], errors)]
        highs = [h + e for h, e in zip(bar_kwargs["height"], errors)]
        if lows:  # with no bars there is nothing to fit; keep the default limits
            # Pad 2% away from the data in both directions. For negative bounds the
            # factors swap, so the padding never cuts into an error bar.
            new_ymin = min(lo * (0.98 if lo >= 0 else 1.02) for lo in lows)
            new_ymax = max(hi * (1.02 if hi >= 0 else 0.98) for hi in highs)
            ymin, ymax = ax.get_ylim()
            ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))