File size: 5,262 Bytes
55c8a69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import matplotlib.pyplot as plt
import io
import numpy as np
import base64


# Color manipulation functions
def hex_to_rgb(hex_color):
    """Convert a ``#RRGGBB`` string (leading ``#`` optional) into an ``(r, g, b)`` tuple.

    Only the first six hex digits are read, so trailing alpha digits such as
    the ``FF`` in ``#RRGGBBFF`` are ignored.
    """
    digits = hex_color.lstrip('#')
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))

def increase_brightness(r, g, b, factor):
    """Blend each channel toward white (255) by ``factor``.

    ``factor=0`` leaves the color unchanged and ``factor=1`` yields pure white.
    Each resulting channel is truncated with ``int``.
    """
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))

def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the color's luma-weighted gray by ``factor``.

    ``factor > 1`` pushes channels away from gray (more saturated), while
    ``factor < 1`` pulls them toward gray (less saturated). Results are
    truncated with ``int`` and are not clamped, so they may fall outside 0-255.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    scaled = []
    for channel in (r, g, b):
        scaled.append(int(luma + (channel - luma) * factor))
    return tuple(scaled)

def rgb_to_hex(r, g, b):
    """Format an RGB triple as a lowercase ``#rrggbb`` string.

    Each channel is clamped to [0, 255] and rounded to the nearest integer.
    Float channels (the usual outcome of color arithmetic) are therefore
    accepted instead of being rejected by the ``x`` format specifier.
    Integer inputs produce exactly the same output as before.
    """
    def _to_byte(value):
        # Clamp first so rounding can never push a value outside 0-255.
        return int(round(min(max(value, 0), 255)))

    r, g, b = (_to_byte(channel) for channel in (r, g, b))
    return f"#{r:02x}{g:02x}{b:02x}"

# Color assignment function
def get_color_for_config(config):
    """Pick the bar color for a benchmark configuration.

    A base color is chosen from ``config["attn_implementation"]`` and, for
    ``"sdpa"``, from the SDPA backend. The color is then lightened with
    ``increase_brightness`` when ``config["compilation"]`` is truthy, and passed
    through ``increase_saturation`` when ``config["kernelize"]`` is truthy.

    Args:
        config: Dict-like mapping with keys ``"attn_implementation"``,
            ``"compilation"`` and ``"kernelize"``. ``"sdpa_backend"`` is
            optional and treated as ``None`` when absent.

    Returns:
        The color as a ``#rrggbb`` hex string.

    Raises:
        ValueError: If ``attn_implementation`` is not recognised.
        KeyError: If the implementation is ``"sdpa"`` and ``sdpa_backend`` is not
            one of the known backends.
    """
    attn_implementation = config["attn_implementation"]
    # Only SDPA distinguishes backends; reading the key leniently lets eager /
    # flash configs omit it, and lets unknown implementations reach the
    # ValueError below instead of failing on a missing "sdpa_backend" key.
    sdpa_backend = config.get("sdpa_backend")

    # Determine the base color for the attention implementation.
    # The 8-digit literals carry a trailing alpha byte; hex_to_rgb reads only the
    # first six hex digits, so the alpha has no effect.
    if attn_implementation == "eager":
        base_color = "#FF6B6B"
    elif attn_implementation == "sdpa":
        sdpa_colors = {
            None: "#4A90E2",
            "math": "#408DDBFF",
            "flash_attention": "#28767EFF",
            "efficient_attention": "#605895FF",
            "cudnn_attention": "#774AE2FF",
        }
        try:
            base_color = sdpa_colors[sdpa_backend]
        except KeyError:
            raise KeyError(f"Unknown SDPA backend: {sdpa_backend!r}") from None
    elif attn_implementation == "flash_attention_2":
        base_color = "#FFD700"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    # Apply color modifications for compilation and kernelization
    r, g, b = hex_to_rgb(base_color)
    if config["compilation"]:
        r, g, b = increase_brightness(r, g, b, 0.3)
    if config["kernelize"]:
        # NOTE(review): a factor below 1 moves each channel toward gray, i.e. it
        # *reduces* saturation despite the helper's name. Confirm that this muted
        # look is what kernelized configs are meant to have.
        r, g, b = increase_saturation(r, g, b, 0.8)

    return rgb_to_hex(r, g, b)


def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
    """Build ``ax.bar`` keyword arguments and error-bar sizes for one metric.

    Each scenario becomes one bar, placed at consecutive integer x positions in
    the mapping's iteration order. The bar height is ``np.median`` of
    ``data[key]``, the color comes from ``get_color_for_config(data["config"])``,
    and the label is the scenario name.

    Returns:
        ``(bar_kwargs, errors)``. ``bar_kwargs`` has the keys ``"x"``,
        ``"height"``, ``"color"`` and ``"label"``. ``errors`` holds
        ``np.std(data[key])`` for each scenario, in the same order.
    """
    heights, colors, labels, errors = [], [], [], []
    for label, scenario in per_scenario_data.items():
        samples = scenario[key]
        heights.append(np.median(samples))
        colors.append(get_color_for_config(scenario["config"]))
        labels.append(label)
        errors.append(np.std(samples))
    bar_kwargs = {
        "x": list(range(len(labels))),
        "height": heights,
        "color": colors,
        "label": labels,
    }
    return bar_kwargs, errors

def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str,
                  *, truncate_y: bool = False):
    """Draw a bar chart with white error bars on a black background.

    Args:
        ax: Axes to draw into.
        bar_kwargs: Keyword arguments forwarded to ``ax.bar``. They must include
            ``"x"`` and ``"height"``, which are reused for the error bars (see
            ``make_bar_kwargs``).
        errors: Symmetric error-bar lengths, passed as ``yerr``, one per bar.
        title: Axes title.
        ylabel: Y-axis label.
        truncate_y: If True, shrink the y-limits toward the span covered by
            ``height +/- error`` (scaled by 0.98 / 1.02) so that small
            differences between bars are visible. The limits are never
            widened. Defaults to False, which leaves matplotlib's automatic
            limits untouched.
    """
    ax.set_facecolor('#000000')
    # Draw bars. width=1.0 makes bars at consecutive integer x positions touch.
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
    # Add error bars centred on each bar top
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
    )
    # Set labels and title
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    # Bars are identified by the shared legend, so x ticks are hidden
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')

    if truncate_y:
        spans = list(zip(bar_kwargs["height"], errors))
        # Guard: min()/max() on an empty sequence would raise
        if spans:
            new_ymin = min(0.98 * (h - e) for h, e in spans)
            new_ymax = max(1.02 * (h + e) for h, e in spans)
            ymin, ymax = ax.get_ylim()
            ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))


def create_matplotlib_bar_plot(per_scenario_data: dict):
    """Create side-by-side TTFT and ITL bar charts and return them as embeddable HTML.

    The left chart shows the ``"ttft"`` samples and the right chart shows the
    ``"itl"`` samples of every scenario, with one shared legend underneath.
    The figure is rendered to a 150-dpi PNG, base64-encoded, and wrapped in a
    full-viewport ``<div>``/``<img>`` snippet.

    The dark style is applied only while this figure is built and saved, so
    global matplotlib settings are left unchanged. The figure is always closed,
    even if drawing fails, so it does not accumulate in pyplot's figure registry.

    Args:
        per_scenario_data: Mapping from scenario name to a dict containing
            ``"ttft"``, ``"itl"`` and ``"config"`` entries (as consumed by
            ``make_bar_kwargs``).

    Returns:
        An HTML string with the chart embedded as a data URI.
    """
    # Scope the dark theme to this figure instead of mutating global rcParams.
    with plt.style.context('dark_background'):
        # Larger figure for more screen space
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT plot (left)
            ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
            draw_bar_plot(ax1, ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")

            # Per-token latency plot (right)
            itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
            draw_bar_plot(ax2, itl_bars, itl_errors, "Time per output token (lower is better)", "ITL (seconds)")

            # Both subplots share scenario order and colors, so one legend
            # with the full (untruncated) labels serves both.
            legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in ttft_bars["color"]]
            fig.legend(legend_handles, ttft_bars["label"], loc='lower center', ncol=1,
                       bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                       labelcolor='white', fontsize=12)

            # Save this specific figure (not whatever is "current") at high DPI
            # for crisp text; bbox_inches='tight' keeps the out-of-axes legend.
            with io.BytesIO() as buffer:
                fig.savefig(buffer, format='png', facecolor='#000000',
                            bbox_inches='tight', dpi=150)
                png_bytes = buffer.getvalue()
        finally:
            plt.close(fig)

    # Base64 output is pure ASCII, suitable for a data URI.
    img_data = base64.b64encode(png_bytes).decode('ascii')

    # Return HTML with embedded image - full height
    html = f"""
    <div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
    </div>
    """
    return html