Spaces:
Sleeping
Sleeping
File size: 6,565 Bytes
d73296f e83d737 d786867 0eac02e d786867 46ba2c6 0eac02e 46ba2c6 e83d737 9aa1b3d e83d737 9aa1b3d e83d737 0eac02e 32319bd 0eac02e 32319bd 0eac02e d73296f 0eac02e d73296f 0ad3140 0eac02e d73296f 0eac02e 4648f45 0eac02e 0ad3140 0eac02e d73296f 0eac02e d73296f 0eac02e 4648f45 0eac02e 0ad3140 0eac02e 0ad3140 32319bd 0ad3140 0eac02e 0ad3140 0eac02e 32319bd 0eac02e 32319bd 0eac02e 32319bd 0eac02e 32319bd 0eac02e 9aa1b3d 0eac02e e83d737 9aa1b3d e83d737 9aa1b3d e83d737 9aa1b3d e83d737 9aa1b3d 29b7ec1 0eac02e e83d737 9aa1b3d e83d737 d786867 46ba2c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from math import e
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64
from data import ModelBenchmarkData
# Configure matplotlib for off-screen rendering: select the non-interactive
# "Agg" backend and switch pyplot's interactive mode off. Figures in this
# module are only ever written to an in-memory PNG, never shown in a window.
matplotlib.use('Agg')
plt.ioff()
# Module-wide benchmark data source shared by the plotting/refresh functions.
# NOTE(review): presumably ModelBenchmarkData reads "data.json" at
# construction time -- confirm against the `data` module.
DATA = ModelBenchmarkData("data.json")
def refresh_plot_data():
    """Return the median wall-clock TTFT/TPOT data as a pandas DataFrame.

    The raw payload obtained from ``DATA`` is echoed to stdout before it is
    converted.
    """
    payload = DATA.get_ttft_tpot_data(use_cuda_time=False, estimator="median")
    print(payload)
    frame = pd.DataFrame(payload)
    return frame
def load_css(path="styles.css"):
    """Load the dashboard's CSS styling.

    Args:
        path: Location of the stylesheet. Defaults to ``styles.css``,
            resolved against the current working directory.

    Returns:
        The stylesheet text, or a minimal dark fallback rule when the file
        does not exist.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "body { background: #000; color: #fff; }"
def _get_color_for_config(label):
    """Pick a bar colour from the keywords found in a configuration label."""
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        # Red for eager compiled, light red for eager uncompiled
        return '#FF4444' if is_compiled else '#FF6B6B'
    if 'sdpa' in lowered:
        # Blue for SDPA compiled, light blue for SDPA uncompiled
        return '#4A90E2' if is_compiled else '#7BB3F0'
    return '#FFD700'  # Yellow for others


def _wall_time_std(raw_data, label, metric):
    """Return the std dev of the ``wall_time`` samples for one label/metric.

    Yields 0 when the label is absent from ``raw_data`` or when there are no
    samples, so the matching error bar simply collapses instead of being NaN.
    """
    import numpy as np

    if label not in raw_data:
        return 0
    samples = [entry['wall_time'] for entry in raw_data[label][metric]]
    return float(np.std(samples)) if samples else 0.0


def _draw_metric_axis(ax, positions, values, errors, colors, ylabel, title):
    """Draw one bar chart with white error bars on ``ax`` in the dark theme."""
    ax.set_facecolor('#000000')
    ax.bar(positions, values,
           color=colors, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        positions, values, yerr=errors,
        fmt='none', ecolor='white', alpha=0.8,
        elinewidth=1.5, capthick=1.5, capsize=4,
    )
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    # Bars are identified through the shared figure legend, not tick labels.
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data.

    Reads the median wall-clock values from the module-level ``DATA`` object,
    draws one bar per configuration label on each chart with standard
    deviation error bars computed from ``DATA.data``, and adds a shared
    legend listing the full labels.

    Assumes ``DATA.get_ttft_tpot_data`` returns a mapping with parallel
    ``'label'``, ``'ttft'`` and ``'tpot'`` sequences, and that ``DATA.data``
    maps each label to ``{'ttft': [...], 'tpot': [...]}`` lists of dicts
    carrying a ``'wall_time'`` key -- TODO confirm against the ``data`` module.

    Returns:
        An HTML snippet embedding the rendered figure as a base64 PNG.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)

    # Prepare data
    labels = data['label']
    ttft_values = data['ttft']
    tpot_values = data['tpot']

    # Error bars: standard deviation of the raw samples per configuration
    raw_data = DATA.data
    ttft_errors = [_wall_time_std(raw_data, label, 'ttft') for label in labels]
    tpot_errors = [_wall_time_std(raw_data, label, 'tpot') for label in labels]

    # One colour per bar, reused for the legend swatches
    colors = [_get_color_for_config(label) for label in labels]
    positions = range(len(labels))

    # Create figure with dark theme - larger for more screen space
    plt.style.use('dark_background')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 16))
    try:
        fig.patch.set_facecolor('#000000')

        _draw_metric_axis(
            ax1, positions, ttft_values, ttft_errors, colors,
            'TTFT (seconds)', 'Time to first token (lower is better)',
        )
        _draw_metric_axis(
            ax2, positions, tpot_values, tpot_errors, colors,
            'TPOT (seconds)', 'Time per output token (lower is better)',
        )

        # Common legend with the full, untruncated labels. `facecolor` is used
        # rather than `color`: `color` would also override the white edge.
        legend_handles = [
            plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor='white')
            for color in colors
        ]
        fig.legend(legend_handles, labels, loc='lower center', ncol=1,
                   bbox_to_anchor=(0.5, -0.05), facecolor='black',
                   edgecolor='white', labelcolor='white', fontsize=12)

        # Operate on `fig` directly instead of pyplot's global current figure
        # so concurrent renders cannot act on each other's figures.
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.3, bottom=0.075)

        # Save plot to bytes with high DPI for crisp text
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=130)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full height
    html = f"""
    <div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
    </div>
    """
    return html
def refresh_plot():
    """Re-render the charts and pair them with the refreshed sidebar text.

    Returns:
        A ``(chart_html, description_markdown)`` tuple, in the order expected
        by the click handler's outputs.
    """
    charts_html = create_matplotlib_bar_charts()
    refreshed_description = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
    return charts_html, refreshed_description
# Build the Gradio UI: a sidebar (title, description, refresh button) next to
# the main area that shows the rendered charts.
with gr.Blocks(
    title="Random Data Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Left-hand sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*",
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                "summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )

        # Main plot area, rendered once when the UI is built
        with gr.Column(elem_classes=["main-content"]):
            plot = gr.HTML(
                create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )

    # Clicking the button regenerates the charts and updates the description
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])


if __name__ == "__main__":
    demo.launch()
|