File size: 6,565 Bytes
d73296f
e83d737
d786867
0eac02e
 
 
 
d786867
46ba2c6
 
0eac02e
 
 
 
46ba2c6
 
 
 
 
 
 
 
 
e83d737
 
9aa1b3d
e83d737
 
9aa1b3d
e83d737
 
 
0eac02e
 
 
 
32319bd
0eac02e
32319bd
0eac02e
 
 
 
 
 
 
d73296f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0eac02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73296f
0ad3140
 
0eac02e
d73296f
 
 
 
 
 
 
0eac02e
4648f45
0eac02e
0ad3140
0eac02e
 
 
 
 
d73296f
0eac02e
 
d73296f
 
 
 
 
 
 
0eac02e
4648f45
0eac02e
0ad3140
0eac02e
 
 
0ad3140
 
 
 
 
32319bd
0ad3140
 
0eac02e
0ad3140
0eac02e
32319bd
0eac02e
 
32319bd
0eac02e
 
 
 
 
 
32319bd
0eac02e
32319bd
 
0eac02e
 
 
 
9aa1b3d
0eac02e
 
e83d737
9aa1b3d
 
e83d737
9aa1b3d
e83d737
 
9aa1b3d
 
e83d737
9aa1b3d
29b7ec1
0eac02e
 
 
e83d737
 
9aa1b3d
 
e83d737
d786867
46ba2c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from math import e
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64

from data import ModelBenchmarkData

# Select the non-interactive Agg backend and switch off pyplot's interactive
# mode: this module only renders figures to in-memory PNG buffers and never
# opens a window.
matplotlib.use('Agg')
plt.ioff()


# Benchmark data source, constructed from "data.json" and shared by the
# chart-building and refresh functions in this module.
DATA = ModelBenchmarkData("data.json")


def refresh_plot_data():
    """Return the median wall-clock TTFT/TPOT data as a pandas DataFrame.

    The raw data is also echoed to stdout before being wrapped.
    """
    ttft_tpot = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    print(ttft_tpot)
    frame = pd.DataFrame(ttft_tpot)
    return frame


def load_css():
    """Return the dashboard stylesheet as a string.

    Reads ``styles.css`` from the current working directory. The file is
    decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.

    Returns:
        The contents of ``styles.css``, or a minimal dark-theme fallback
        rule when the file does not exist.
    """
    try:
        with open("styles.css", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # Best-effort: the dashboard still renders without the stylesheet.
        return "body { background: #000; color: #fff; }"

def _color_for_config(label):
    """Return the hex bar color for a configuration label.

    Matching is case-insensitive and substring-based: labels containing
    'eager' are red, labels containing 'sdpa' are blue, everything else is
    yellow. A '_compiled' substring selects the stronger shade.
    """
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        return '#FF4444' if is_compiled else '#FF6B6B'
    if 'sdpa' in lowered:
        return '#4A90E2' if is_compiled else '#7BB3F0'
    return '#FFD700'


def _wall_time_std(samples):
    """Return the standard deviation of the samples' 'wall_time' values (0.0 if none)."""
    import numpy as np

    wall_times = [sample['wall_time'] for sample in samples]
    # np.std([]) yields NaN plus a RuntimeWarning; report "no spread" instead.
    return float(np.std(wall_times)) if wall_times else 0.0


def _draw_bar_panel(ax, positions, values, errors, colors, ylabel, title):
    """Draw one dark-themed bar chart with white error bars onto ``ax``."""
    ax.set_facecolor('#000000')
    ax.bar(positions, values,
           color=colors, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        positions, values, yerr=errors,
        fmt='none', ecolor='white', alpha=0.8,
        elinewidth=1.5, capthick=1.5, capsize=4,
    )
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    # Bars are identified through the shared figure legend, not x tick labels.
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data.

    The left panel shows time to first token and the right panel time per
    output token, one bar per configuration label, with error bars equal to
    the standard deviation of the raw 'wall_time' samples. A shared legend
    below the panels maps bar colors to the full configuration labels.

    Assumes ``DATA.get_ttft_tpot_data`` returns parallel 'label', 'ttft' and
    'tpot' sequences and that ``DATA.data[label]`` holds 'ttft'/'tpot' lists
    of dicts with a 'wall_time' entry -- TODO confirm against
    ``ModelBenchmarkData``.

    Returns:
        An HTML snippet (str) embedding the rendered figure as a
        base64-encoded PNG.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    labels = data['label']
    ttft_values = data['ttft']
    tpot_values = data['tpot']

    # Error bars: spread of the raw samples behind each aggregated value.
    raw_data = DATA.data
    ttft_errors = []
    tpot_errors = []
    for label in labels:
        if label in raw_data:
            ttft_errors.append(_wall_time_std(raw_data[label]['ttft']))
            tpot_errors.append(_wall_time_std(raw_data[label]['tpot']))
        else:
            # No raw samples for this label: draw the bar without a spread.
            ttft_errors.append(0.0)
            tpot_errors.append(0.0)

    colors = [_color_for_config(label) for label in labels]
    positions = range(len(labels))

    # Create figure with dark theme - larger for more screen space
    plt.style.use('dark_background')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 16))
    try:
        fig.patch.set_facecolor('#000000')

        _draw_bar_panel(
            ax1, positions, ttft_values, ttft_errors, colors,
            'TTFT (seconds)', 'Time to first token (lower is better)',
        )
        _draw_bar_panel(
            ax2, positions, tpot_values, tpot_errors, colors,
            'TPOT (seconds)', 'Time per output token (lower is better)',
        )

        # Common legend with the full, untruncated labels. `facecolor` is used
        # rather than `color`, because `color` would also overwrite the white
        # edge (and makes matplotlib emit a warning).
        legend_handles = [
            plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor='white')
            for color in colors
        ]
        fig.legend(legend_handles, labels, loc='lower center', ncol=1,
                   bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=12)

        # Tight layout with spacing between subplots and extra bottom space
        # for the legend. Figure methods are used instead of pyplot's implicit
        # "current figure" so the calls always target this figure.
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.3, bottom=0.075)

        # Save plot to bytes with high DPI for crisp text
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=130)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full height
    html = f"""
    <div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
    </div>
    """
    return html

def refresh_plot():
    """Re-render the charts and return them with the "refreshed" sidebar text.

    Returns:
        A ``(chart_html, description_markdown)`` tuple.
    """
    refreshed_description = (
        "**Transformer CI Dashboard**<br>-<br>"
        "**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>"
        "*This dashboard only tracks important models*<br>"
        "*(Data refreshed)*"
    )
    chart_html = create_matplotlib_bar_charts()
    return chart_html, refreshed_description

# Assemble the Gradio UI: a sidebar (title, description, refresh button)
# next to the main chart area.
# NOTE(review): the page title "Random Data Dashboard" looks like a template
# leftover for what the sidebar calls the Transformer CI Dashboard -- confirm.
with gr.Blocks(
    title="Random Data Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Left-hand sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*",
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                "summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )

        # Chart area; the initial chart is rendered while the layout is built
        with gr.Column(elem_classes=["main-content"]):
            plot = gr.HTML(
                create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )

    # Clicking the button regenerates the chart and the sidebar description
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])

if __name__ == "__main__":
    demo.launch()