ror's picture
ror HF Staff
matplotlib
0eac02e
raw
history blame
5.35 kB
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64
from data import ModelBenchmarkData
# Select the non-interactive Agg backend and switch pyplot's interactive mode
# off: the figures built in this file are only saved to an in-memory PNG
# buffer, never shown in a window.
matplotlib.use('Agg')
plt.ioff()
# Module-level benchmark data source, constructed from "data.json" and queried
# by the functions below through get_ttft_tpot_data().
# NOTE(review): ModelBenchmarkData lives in the local `data` module; what its
# constructor does with the path is not visible from this file.
DATA = ModelBenchmarkData("data.json")
def refresh_plot_data():
    """Fetch the median TTFT/TPOT benchmark data and wrap it in a DataFrame.

    The raw result of ``DATA.get_ttft_tpot_data`` is echoed to stdout before
    being converted.

    Returns:
        pandas.DataFrame: built directly from the value returned by
        ``DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)``.
    """
    benchmark_rows = DATA.get_ttft_tpot_data(
        estimator="median",
        use_cuda_time=False,
    )
    # Kept as-is: the raw payload is printed before conversion.
    print(benchmark_rows)
    frame = pd.DataFrame(benchmark_rows)
    return frame
def load_css():
    """Load the dashboard stylesheet from ``styles.css``.

    The file is looked up relative to the current working directory and is
    decoded as UTF-8 explicitly, so the result does not depend on the
    platform's default locale encoding.

    Returns:
        str: the contents of ``styles.css``, or a minimal black-background
        fallback rule when the file does not exist.
    """
    try:
        with open("styles.css", "r", encoding="utf-8") as css_file:
            return css_file.read()
    except FileNotFoundError:
        # Best-effort fallback: keep the dark look even without the stylesheet.
        return "body { background: #000; color: #fff; }"
def _get_color_for_config(label):
    """Return the bar colour for a configuration label.

    The match is case-insensitive and keyword based: labels containing
    ``eager`` are red, labels containing ``sdpa`` are blue, anything else is
    yellow. Within the red/blue families, labels containing ``_compiled`` get
    the stronger shade and the rest get the lighter one.
    """
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        # Red for eager compiled, light red for eager uncompiled.
        return '#FF4444' if is_compiled else '#FF6B6B'
    if 'sdpa' in lowered:
        # Blue for SDPA compiled, light blue for SDPA uncompiled.
        return '#4A90E2' if is_compiled else '#7BB3F0'
    # Yellow for every other configuration.
    return '#FFD700'


def _draw_bar_axis(ax, labels, values, colors, ylabel, title):
    """Draw one bar per label on *ax* with the dashboard's black/white styling."""
    positions = range(len(labels))
    # Long labels are cut to 12 characters plus an ellipsis to limit overlap.
    tick_labels = [
        label[:12] + '...' if len(label) > 12 else label
        for label in labels
    ]

    ax.set_facecolor('#000000')
    ax.bar(positions, values,
           color=colors, width=1.0, edgecolor='white', linewidth=1)
    ax.set_xlabel('Model Configuration', color='white', fontsize=14)
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels,
                       rotation=45, ha='right', color='white', fontsize=10)
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data.

    Reads the ``label``, ``ttft`` and ``tpot`` entries of
    ``DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)``,
    draws the TTFT chart on the left and the TPOT chart on the right, renders
    the figure to PNG and embeds it in an HTML snippet.

    Returns:
        str: an HTML ``<div>`` containing the chart as a base64 ``data:`` URI
        image.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)

    # Prepare data
    labels = data['label']
    ttft_values = data['ttft']
    tpot_values = data['tpot']
    colors = [_get_color_for_config(label) for label in labels]

    # Apply the dark theme only while this figure is built and saved, instead
    # of permanently changing the process-wide matplotlib style.
    with plt.style.context('dark_background'):
        # Wider figure for side-by-side plots.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 12))
        try:
            fig.patch.set_facecolor('#000000')

            # TTFT plot (left)
            _draw_bar_axis(ax1, labels, ttft_values, colors,
                           'TTFT (seconds)',
                           'Time To First Token by Configuration')
            # TPOT plot (right)
            _draw_bar_axis(ax2, labels, tpot_values, colors,
                           'TPOT (seconds)',
                           'Time Per Output Token by Configuration')

            # Tight layout to prevent label cutoff. The figure's own methods
            # are used rather than pyplot's implicit "current figure".
            fig.tight_layout()

            # Save plot to bytes
            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', facecolor='#000000',
                        bbox_inches='tight', dpi=100)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - almost full height
    return (
        "\n"
        '<div style="width: 100%; height: 95vh; background: #000; display: flex; '
        'justify-content: center; align-items: center;">\n'
        f'<img src="data:image/png;base64,{img_data}" '
        'style="max-width: 100%; max-height: 100%; object-fit: contain;" />\n'
        "</div>\n"
    )
def refresh_plot():
    """Rebuild the charts and return them with the refreshed sidebar text.

    Returns:
        tuple[str, str]: the HTML produced by
        ``create_matplotlib_bar_charts()`` and the sidebar description
        markdown ending with the "(Data refreshed)" marker.
    """
    chart_html = create_matplotlib_bar_charts()
    refreshed_description = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
    return chart_html, refreshed_description
# Build the Gradio UI: a sidebar (title, description, refresh button) next to
# the main area holding the rendered charts.
# NOTE(review): the page title "Random Data Dashboard" looks like a template
# leftover next to the "Transformer CI Dashboard" sidebar text - confirm.
with gr.Blocks(
    title="Random Data Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Sidebar column
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown(value="# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                value="**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*",
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                value="summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )

        # Main column: the charts are rendered once here, while the UI is built.
        with gr.Column(elem_classes=["main-content"]):
            plot = gr.HTML(
                value=create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )

    # Clicking the button regenerates the charts and updates the description.
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()