Spaces:
Sleeping
Sleeping
File size: 5,347 Bytes
e83d737 d786867 0eac02e d786867 46ba2c6 0eac02e 46ba2c6 e83d737 9aa1b3d e83d737 9aa1b3d e83d737 0eac02e 9aa1b3d 0eac02e e83d737 9aa1b3d e83d737 9aa1b3d e83d737 9aa1b3d e83d737 9aa1b3d 29b7ec1 0eac02e e83d737 9aa1b3d e83d737 d786867 46ba2c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64
from data import ModelBenchmarkData
# Configure matplotlib for headless rendering.
# Agg is a non-interactive backend: charts below are only rendered to PNG
# bytes and never shown in a window.
# NOTE(review): matplotlib recommends selecting the backend before
# `matplotlib.pyplot` is imported (it is imported above); recent versions
# allow switching here — confirm for the pinned matplotlib version.
matplotlib.use('Agg')
# Disable interactive mode so pyplot calls never try to redraw/show figures.
plt.ioff()
# Shared benchmark data source, constructed once at import time.
# NOTE(review): presumably reads "data.json" relative to the current working
# directory — confirm in data.ModelBenchmarkData.
DATA = ModelBenchmarkData("data.json")
def refresh_plot_data(estimator="median", use_cuda_time=False):
    """Return the TTFT/TPOT benchmark data as a pandas DataFrame.

    Args:
        estimator: Statistic name forwarded to ``DATA.get_ttft_tpot_data``
            (defaults to ``"median"``, the previously hard-coded value).
        use_cuda_time: Flag forwarded to ``DATA.get_ttft_tpot_data``
            (defaults to ``False``, the previously hard-coded value).

    Returns:
        A ``pd.DataFrame`` built from whatever ``get_ttft_tpot_data`` returns.
    """
    import logging

    data = DATA.get_ttft_tpot_data(estimator=estimator, use_cuda_time=use_cuda_time)
    # Debug-level log instead of an unconditional print(), which dumped the
    # entire payload to stdout on every call.
    logging.getLogger(__name__).debug("TTFT/TPOT data: %s", data)
    return pd.DataFrame(data)
def load_css(path="styles.css"):
    """Load the dashboard stylesheet.

    Args:
        path: Stylesheet location; defaults to ``"styles.css"`` resolved
            against the current working directory.

    Returns:
        The file's contents decoded as UTF-8. If the file does not exist,
        returns a minimal black-background / white-text stylesheet so the
        UI still renders.
    """
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on or garble non-ASCII CSS such as emoji content.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "body { background: #000; color: #fff; }"
def _color_for_config(label):
    """Map a configuration label to its bar colour.

    Matching is case-insensitive substring search on *label*: 'eager' takes
    precedence over 'sdpa', and '_compiled' selects the darker shade.
    Any other label is drawn yellow.
    """
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        return '#FF4444' if is_compiled else '#FF6B6B'  # red / light red
    if 'sdpa' in lowered:
        return '#4A90E2' if is_compiled else '#7BB3F0'  # blue / light blue
    return '#FFD700'  # yellow for other configurations


def _short_label(label, max_len=12):
    """Truncate *label* to *max_len* characters plus '...' when it is longer."""
    return label[:max_len] + '...' if len(label) > max_len else label


def _draw_metric_bars(ax, labels, values, colors, ylabel, title):
    """Draw one black-background bar chart (one bar per label) on *ax*."""
    positions = range(len(labels))
    ax.set_facecolor('#000000')
    ax.bar(positions, values,
           color=colors, width=1.0, edgecolor='white', linewidth=1)
    ax.set_xlabel('Model Configuration', color='white', fontsize=14)
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels([_short_label(label) for label in labels],
                       rotation=45, ha='right', color='white', fontsize=10)
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data.

    Fetches median, non-CUDA-time data from ``DATA`` and draws TTFT (left)
    and TPOT (right) on a black theme. Bars are coloured by configuration
    keywords in the label. The figure is rendered to PNG in memory.

    Returns:
        An HTML snippet embedding the PNG as a base64 data URI inside a
        near-full-height black container.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    labels = data['label']
    colors = [_color_for_config(label) for label in labels]

    # Scope the dark style to this figure instead of mutating global rcParams
    # on every call.
    with plt.style.context('dark_background'):
        fig, (ax_ttft, ax_tpot) = plt.subplots(1, 2, figsize=(20, 12))
        try:
            fig.patch.set_facecolor('#000000')
            _draw_metric_bars(ax_ttft, labels, data['ttft'], colors,
                              'TTFT (seconds)',
                              'Time To First Token by Configuration')
            _draw_metric_bars(ax_tpot, labels, data['tpot'], colors,
                              'TPOT (seconds)',
                              'Time Per Output Token by Configuration')
            # Use the figure's own methods rather than pyplot's implicit
            # "current figure", which is shared global state and unsafe when
            # requests are handled on several threads.
            fig.tight_layout()
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', facecolor='#000000',
                        bbox_inches='tight', dpi=100)
        finally:
            # Always unregister the figure from pyplot, even on error,
            # so figures do not accumulate across refreshes.
            plt.close(fig)

    img_data = base64.b64encode(buffer.getvalue()).decode()
    return (
        '<div style="width: 100%; height: 95vh; background: #000; '
        'display: flex; justify-content: center; align-items: center;">'
        f'<img src="data:image/png;base64,{img_data}" '
        'style="max-width: 100%; max-height: 100%; object-fit: contain;" />'
        '</div>'
    )
def refresh_plot():
    """Regenerate the chart HTML and the sidebar description.

    Returns:
        A ``(html, markdown)`` pair: freshly rendered chart HTML, and the
        sidebar description text with a trailing "(Data refreshed)" marker.
    """
    chart_html = create_matplotlib_bar_charts()
    sidebar_text = (
        "**Transformer CI Dashboard**<br>-<br>"
        "**AMD runs on MI325**<br>"
        "**NVIDIA runs on A10**<br><br>"
        "*This dashboard only tracks important models*<br>"
        "*(Data refreshed)*"
    )
    return chart_html, sidebar_text
# Create Gradio interface: a sidebar (title, description, refresh button)
# next to the main chart area.
with gr.Blocks(
    # Tab title now matches the app's own name; it previously read
    # "Random Data Dashboard", a leftover from a template.
    title="Transformer CI Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Sidebar
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*",
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                "summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )
        # Main plot area; the initial chart is rendered once at startup.
        with gr.Column(elem_classes=["main-content"]):
            plot = gr.HTML(
                create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )
    # Button click handler: refresh_plot returns (chart_html, description_md)
    # in the same order as these outputs.
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])

if __name__ == "__main__":
    demo.launch()
|