File size: 5,347 Bytes
e83d737
d786867
0eac02e
 
 
 
d786867
46ba2c6
 
0eac02e
 
 
 
46ba2c6
 
 
 
 
 
 
 
 
e83d737
 
9aa1b3d
e83d737
 
9aa1b3d
e83d737
 
 
0eac02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aa1b3d
0eac02e
 
e83d737
9aa1b3d
 
e83d737
9aa1b3d
e83d737
 
9aa1b3d
 
e83d737
9aa1b3d
29b7ec1
0eac02e
 
 
e83d737
 
9aa1b3d
 
e83d737
d786867
46ba2c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64

from data import ModelBenchmarkData

# Select the non-interactive Agg backend: in this file charts are only
# rendered to in-memory PNG bytes, never shown in a GUI window.
# NOTE(review): this runs after `matplotlib.pyplot` has already been imported;
# recent matplotlib versions allow switching backends at this point as long as
# no figure exists yet — confirm against the pinned matplotlib version.
matplotlib.use('Agg')
# Disable pyplot interactive mode so plotting calls never try to show/redraw.
plt.ioff()


# Module-level benchmark data source shared by every callback below.
# "data.json" is a relative path, presumably resolved against the current
# working directory — confirm in ModelBenchmarkData (data.py).
DATA = ModelBenchmarkData("data.json")


def refresh_plot_data(*, estimator="median", use_cuda_time=False):
    """Fetch TTFT/TPOT benchmark data and return it as a DataFrame.

    Args:
        estimator: Aggregation name forwarded to
            ``DATA.get_ttft_tpot_data``. Defaults to ``"median"``, the value
            this function previously hard-coded.
        use_cuda_time: Flag forwarded to ``DATA.get_ttft_tpot_data``.
            Defaults to ``False``, the value previously hard-coded.

    Returns:
        A ``pandas.DataFrame`` built from whatever
        ``DATA.get_ttft_tpot_data`` returns.
    """
    data = DATA.get_ttft_tpot_data(estimator=estimator, use_cuda_time=use_cuda_time)
    return pd.DataFrame(data)


def load_css(path="styles.css"):
    """Return the dashboard stylesheet as a string.

    Args:
        path: Stylesheet file to read. Defaults to ``"styles.css"``; a
            relative path is resolved by ``open`` against the current
            working directory.

    Returns:
        The file's text. If the file does not exist, returns a minimal
        black-background / white-text rule so the UI still renders.
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # can fail on non-ASCII characters in the stylesheet.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "body { background: #000; color: #fff; }"

def _color_for_config(label):
    """Map a configuration label to its bar colour.

    Matching is a case-insensitive substring test. "eager" is checked before
    "sdpa", and "_compiled" selects the darker shade of each family. Labels
    matching neither family are drawn in yellow.
    """
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        return '#FF4444' if is_compiled else '#FF6B6B'  # red / light red
    if 'sdpa' in lowered:
        return '#4A90E2' if is_compiled else '#7BB3F0'  # blue / light blue
    return '#FFD700'  # yellow for any other configuration


def _short_label(label, max_len=12):
    """Truncate *label* to *max_len* characters, appending '...' when cut."""
    return label[:max_len] + '...' if len(label) > max_len else label


def _draw_bar_panel(ax, labels, values, colors, ylabel, title):
    """Draw one black-background bar chart with one bar per configuration."""
    positions = range(len(labels))
    ax.set_facecolor('#000000')
    ax.bar(positions, values, color=colors, width=1.0, edgecolor='white', linewidth=1)

    ax.set_xlabel('Model Configuration', color='white', fontsize=14)
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)

    ax.set_xticks(positions)
    ax.set_xticklabels([_short_label(label) for label in labels],
                       rotation=45, ha='right', color='white', fontsize=10)

    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts(data=None):
    """Render side-by-side TTFT and TPOT bar charts as an embeddable HTML snippet.

    Args:
        data: Optional mapping with ``'label'``, ``'ttft'`` and ``'tpot'``
            sequences of equal length. When ``None`` (the default), it is
            fetched via ``DATA.get_ttft_tpot_data(estimator="median",
            use_cuda_time=False)``.

    Returns:
        An HTML ``<div>`` string containing the chart as a base64-encoded
        PNG ``<img>``.
    """
    # The OO Figure API does not register the figure with pyplot's global
    # figure manager. There is no shared "current figure" between concurrent
    # requests, and nothing leaks if drawing raises (no plt.close() needed).
    from matplotlib.figure import Figure

    if data is None:
        data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)

    labels = data['label']
    colors = [_color_for_config(label) for label in labels]

    # Scope the dark theme to this render. The previous plt.style.use() call
    # permanently rewrote the process-wide rcParams on every invocation.
    with plt.style.context('dark_background'):
        fig = Figure(figsize=(20, 12))  # wide canvas for two side-by-side panels
        fig.patch.set_facecolor('#000000')
        ax_ttft, ax_tpot = fig.subplots(1, 2)

        _draw_bar_panel(ax_ttft, labels, data['ttft'], colors,
                        'TTFT (seconds)', 'Time To First Token by Configuration')
        _draw_bar_panel(ax_tpot, labels, data['tpot'], colors,
                        'TPOT (seconds)', 'Time Per Output Token by Configuration')

        # Prevent rotated tick labels from being cut off.
        fig.tight_layout()

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=100)

    # Inline the PNG as base64 so the HTML component needs no file serving.
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - almost full height
    html = f"""
    <div style="width: 100%; height: 95vh; background: #000; display: flex; justify-content: center; align-items: center;">
        <img src="data:image/png;base64,{img_data}" style="max-width: 100%; max-height: 100%; object-fit: contain;" />
    </div>
    """
    return html

def refresh_plot():
    """Re-render the charts and produce the sidebar text for a refresh.

    Returns:
        A ``(chart_html, sidebar_markdown)`` pair. The second item is the
        static sidebar description with a trailing "(Data refreshed)" note.
    """
    charts_html = create_matplotlib_bar_charts()
    sidebar_markdown = (
        "**Transformer CI Dashboard**<br>-<br>"
        "**AMD runs on MI325**<br>"
        "**NVIDIA runs on A10**<br><br>"
        "*This dashboard only tracks important models*<br>"
        "*(Data refreshed)*"
    )
    return charts_html, sidebar_markdown

# Create Gradio interface: a narrow sidebar (title, description, refresh
# button) next to a main column holding the chart HTML.
# NOTE(review): the page title "Random Data Dashboard" does not match the
# "TCID" branding shown in the sidebar — possibly a template leftover.
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
    # Components are placed in the order they are created inside these
    # nested Row/Column contexts, so statement order here defines the layout.
    with gr.Row():
        # Sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            # Kept in a variable because the click handler below rewrites it.
            description = gr.Markdown("**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*", elem_classes=["sidebar-description"])
            summary_btn = gr.Button("summary\n📊", variant="primary", size="lg", elem_classes=["summary-button"])

        # Main plot area
        with gr.Column(elem_classes=["main-content"]):
            # The initial chart HTML is rendered once, when this module is
            # imported. In this file it is only replaced via the button below.
            plot = gr.HTML(
                create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )

    # Button click handler: refresh_plot returns a 2-tuple that is assigned
    # positionally to [plot, description].
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])

# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()