File size: 8,822 Bytes
55c8a69
79e7993
55c8a69
 
 
 
dc41c89
79e7993
55c8a69
 
e1f4b73
 
 
 
 
 
79e7993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f4b73
 
 
 
 
55c8a69
165f130
 
9f6e83d
 
 
 
79e7993
 
9f6e83d
 
 
 
 
 
 
 
 
 
79e7993
165f130
 
 
 
79e7993
 
 
 
 
 
 
 
 
 
dc41c89
 
79e7993
dc41c89
 
 
165f130
dc41c89
 
 
79e7993
 
dc41c89
 
 
165f130
79e7993
dc41c89
 
 
 
 
 
 
 
79e7993
dc41c89
55c8a69
 
e1f4b73
79e7993
c4f3c79
79e7993
55c8a69
22cf82d
165f130
22cf82d
 
 
 
 
79e7993
 
 
 
 
22cf82d
 
 
79e7993
 
22cf82d
 
165f130
59644f0
165f130
22cf82d
55c8a69
59644f0
165f130
22cf82d
 
 
79e7993
 
 
 
 
 
 
dc41c89
55c8a69
 
79e7993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc41c89
 
79e7993
 
 
 
 
 
 
 
 
 
55c8a69
 
 
79e7993
55c8a69
 
 
 
 
 
e1f4b73
55c8a69
e1f4b73
 
55c8a69
 
 
dc41c89
 
c4f3c79
79e7993
c4f3c79
 
 
 
 
 
79e7993
 
dc41c89
79e7993
dc41c89
 
79e7993
 
 
 
 
 
 
 
 
 
dc41c89
22cf82d
79e7993
dc41c89
79e7993
dc41c89
 
c4f3c79
 
 
 
 
79e7993
c4f3c79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import io
import numpy as np
import base64

from plot_utils import get_color_for_config
from data import load_data, ModelBenchmarkData


def reorder_data(per_scenario_data: dict) -> dict:
    """Return the scenarios of ``per_scenario_data`` in plotting order.

    Scenarios are sorted by attention implementation, then SDPA backend, then
    the ``kernelize`` flag, and finally by whether a compile mode is set.
    Attention implementations or SDPA backends that are absent from the
    priority tables sort after every known value instead of raising
    ``KeyError``. Ties keep their original relative order.

    Args:
        per_scenario_data: Mapping from scenario name to a dict holding a
            ``"config"`` dict with the keys ``"attn_implementation"``,
            ``"sdpa_backend"``, ``"kernelize"`` and ``"compile_mode"``.

    Returns:
        A new dict with the same items, inserted in sorted order.
    """
    # Priority tables are built once rather than on every key evaluation
    attn_impl_prio = {
        "flash_attention_2": 0,
        "sdpa": 1,
        "eager": 2,
        "flex_attention": 3,
    }
    sdpa_backend_prio = {
        None: -1,
        "flash_attention": 0,
        "math": 1,
        "efficient_attention": 2,
        "cudnn_attention": 3,
    }
    # Fallback ranks are larger than any known rank so unknown values go last
    unknown_attn_prio = len(attn_impl_prio)
    unknown_backend_prio = len(sdpa_backend_prio)

    def sorting_fn(key: str) -> tuple[int, int, bool, bool]:
        cfg = per_scenario_data[key]["config"]
        return (
            attn_impl_prio.get(cfg["attn_implementation"], unknown_attn_prio),
            sdpa_backend_prio.get(cfg["sdpa_backend"], unknown_backend_prio),
            cfg["kernelize"],
            cfg["compile_mode"] is not None,
        )

    ordered_keys = sorted(per_scenario_data, key=sorting_fn)
    return {k: per_scenario_data[k] for k in ordered_keys}


def infer_bar_label(config: dict) -> str:
    """Build the human-readable legend label for a benchmark configuration.

    The label has three comma-separated parts: the attention implementation
    (with the backend for SDPA), whether a compile mode is set, and whether
    kernels are enabled.
    """
    attn_key = config["attn_implementation"]
    if attn_key == "sdpa":
        # SDPA is further qualified by the backend it runs on
        sdpa_names = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }
        attn_name = sdpa_names.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        plain_names = {
            "eager": "Eager",
            "flash_attention_2": "Flash attention",
            "flex_attention": "Flex attention",
        }
        attn_name = plain_names.get(attn_key, "Unknown")

    compile_part = "no compile" if config["compile_mode"] is None else "compiled"
    kernels_part = "kernelized" if config["kernelize"] else "no kernels"
    return ", ".join([attn_name, compile_part, kernels_part])


def infer_bar_hatch(config: dict) -> str:
    """Return the bar hatch pattern: ``"/"`` when a compile mode is set, else ``""``."""
    return "" if config["compile_mode"] is None else "/"


def make_bar_kwargs(
    per_device_data: dict[str, ModelBenchmarkData], key: str
) -> tuple[dict, list, list]:
    """Assemble the bar, error-bar and x-tick data for one metric.

    Bars of the same device are placed on consecutive integer x positions and
    devices are separated by an extra gap. Each bar's height is the median of
    the scenario's measurements and its error is their standard deviation.
    Devices without any scenario are skipped, so they produce neither a tick
    nor an empty gap.

    Args:
        per_device_data: Mapping from device name to that device's benchmark data.
        key: Name of the measurement series to plot in each scenario's data
            (the callers in this file use ``"ttft"`` and ``"itl"``).

    Returns:
        A ``(bar_kwargs, errors_bars, x_ticks)`` tuple where ``bar_kwargs``
        holds parallel ``x``/``height``/``color``/``label``/``hatch`` lists,
        ``errors_bars`` holds one error value per bar, and ``x_ticks`` holds
        one ``(position, device_name)`` pair per plotted device.
    """
    # Prepare accumulators
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
    errors_bars = []
    x_ticks = []

    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # np.mean([]) would yield a NaN tick position; nothing to draw here
            continue
        device_xs = []

        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            measurements = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(measurements))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            bar_kwargs["hatch"].append(infer_bar_hatch(config))
            errors_bars.append(np.std(measurements))
            device_xs.append(current_x)
            current_x += 1

        # The device label is centered under its group of bars
        x_ticks.append((np.mean(device_xs), device_name))
        # Leave a gap before the next device's group
        current_x += 1.5
    return bar_kwargs, errors_bars, x_ticks


def _figure_to_html(fig: "plt.Figure") -> str:
    """Render ``fig`` as a PNG and wrap it in an HTML snippet as a base64 data URI."""
    # Save plot to bytes with high DPI for crisp text
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full page coverage
    return f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """


def create_matplotlib_bar_plot() -> str:
    """Create stacked matplotlib bar charts for TTFT and ITL and return them as HTML.

    Two charts sharing the x axis are drawn: time to first token on top and
    inter-token latency below, with a single legend underneath. If the loaded
    devices disagree on batch size, sequence length or number of tokens to
    generate, no bars are drawn and the figure title describes the mismatch.
    The figure is always closed before returning, including on errors.

    Returns:
        HTML markup embedding the rendered figure as a base64-encoded PNG.
    """
    # Create figure with dark theme - maximum size for full screen
    plt.style.use("dark_background")
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)  # used to be 30, 16
    try:
        fig.patch.set_facecolor("#000000")

        # Load data and ensure coherence
        per_device_data = load_data()
        batch_size, sequence_length, num_tokens_to_generate = None, None, None
        for device_data in per_device_data.values():
            bs, seqlen, n_tok = device_data.ensure_coherence()
            if batch_size is None:
                batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
            elif (bs, seqlen, n_tok) != (
                batch_size,
                sequence_length,
                num_tokens_to_generate,
            ):
                fig.suptitle(
                    f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                    f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                    color="white",
                    fontsize=18,
                )
                # Render the figure so the mismatch message is actually shown
                return _figure_to_html(fig)

        # TTFT Plot (top)
        ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)

        # ITL Plot (bottom)
        itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

        # Title and tight layout
        title = "\n".join(
            [
                "Time to first token and inter-token latency (lower is better)",
                f"Batch size: {batch_size},  sequence length: {sequence_length},  new tokens: {num_tokens_to_generate}",
            ]
        )
        fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
        plt.tight_layout()

        # Common legend: one entry per distinct label, first occurrence wins
        legend_entries = {}
        for label, color, hatch in zip(
            ttft_bars["label"], ttft_bars["color"], ttft_bars["hatch"]
        ):
            legend_entries.setdefault(label, (color, hatch))

        # NOTE: making all attn implementations equally represented in the
        # legend (padding with blank entries) is not implemented.
        legend_handles = [
            mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
            for label, (color, hatch) in legend_entries.items()
        ]

        # Put the shared legend below the plots, centered
        fig.legend(
            handles=legend_handles,
            loc="lower center",
            ncol=4,
            bbox_to_anchor=(0.515, -0.11),
            facecolor="black",
            edgecolor="white",
            labelcolor="white",
            fontsize=14,
        )

        return _figure_to_html(fig)
    finally:
        # Figures created through pyplot stay registered until closed explicitly
        plt.close(fig)


def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw a white-on-black bar chart with error bars on ``ax``.

    Args:
        ax: Axes to draw on.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``; its ``"x"`` and
            ``"height"`` entries also position the error bars.
        errors: Error magnitude for each bar, used as ``yerr``.
        ylabel: Label of the y axis.
        xticks: ``(position, label)`` pairs for the x axis ticks.
        adapt_ylim: When true, tighten the y limits to 2% around the span of
            the bars plus/minus their errors, never widening the current limits.
    """
    ax.set_facecolor("#000000")
    ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)

    # Bars are drawn above the grid, error bars above the bars (see zorder)
    ax.bar(width=1.0, edgecolor="white", linewidth=1, zorder=3, **bar_kwargs)
    errorbar_style = {
        "fmt": "none",
        "ecolor": "white",
        "alpha": 0.8,
        "elinewidth": 1.5,
        "capthick": 1.5,
        "capsize": 4,
        "zorder": 4,
    }
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors, **errorbar_style
    )

    # Axis label, then reset the x ticks and place one labelled tick per entry
    ax.set_ylabel(ylabel, color="white", fontsize=16)
    ax.set_xticks([])
    ax.tick_params(colors="white", labelsize=13)
    tick_positions = [tick[0] for tick in xticks]
    tick_labels = [tick[1] for tick in xticks]
    ax.set_xticks(tick_positions, tick_labels, fontsize=16)

    if not adapt_ylim:
        return

    # Truncate axis to better fit the bars
    lower_bounds = [0.98 * (h - e) for h, e in zip(bar_kwargs["height"], errors)]
    upper_bounds = [1.02 * (h + e) for h, e in zip(bar_kwargs["height"], errors)]
    tight_ymin = min([1e9, *lower_bounds])
    tight_ymax = max([-1e9, *upper_bounds])
    current_ymin, current_ymax = ax.get_ylim()
    ax.set_ylim(max(current_ymin, tight_ymin), min(current_ymax, tight_ymax))