Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import io | |
| import base64 | |
| from data import ModelBenchmarkData | |
# Render off-screen: the non-interactive Agg backend needs no display, which
# suits a server that only ever writes figures into in-memory PNG buffers.
matplotlib.use("Agg")
plt.ioff()  # interactive mode off: figures are never shown or auto-redrawn

# Benchmark source shared by refresh_plot_data() and
# create_matplotlib_bar_charts(); built from "data.json".
DATA = ModelBenchmarkData("data.json")
def refresh_plot_data(*, estimator="median", use_cuda_time=False):
    """Return the current TTFT/TPOT benchmark numbers as a DataFrame.

    Args:
        estimator: Aggregation name forwarded to ``DATA.get_ttft_tpot_data``.
            Defaults to ``"median"``, the same value the chart uses.
        use_cuda_time: Flag forwarded unchanged to
            ``DATA.get_ttft_tpot_data``. Defaults to ``False``.

    Returns:
        pandas.DataFrame built from whatever ``get_ttft_tpot_data`` returns
        (presumably a mapping of column name -> sequence; TODO confirm).
    """
    data = DATA.get_ttft_tpot_data(estimator=estimator, use_cuda_time=use_cuda_time)
    return pd.DataFrame(data)
def load_css(path="styles.css"):
    """Return the dashboard stylesheet as a string.

    Reads ``path`` as UTF-8 text. If the file does not exist, a minimal
    black-background / white-text stylesheet is returned instead, so the UI
    still renders.

    Args:
        path: Stylesheet location. A relative path is resolved against the
            current working directory. Defaults to ``"styles.css"``.

    Returns:
        str: CSS source, passed to ``gr.Blocks(css=...)`` by this app.
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would mis-decode non-ASCII characters in the stylesheet.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "body { background: #000; color: #fff; }"
def _color_for_config(label):
    """Pick a bar color from keywords in a configuration label.

    Labels containing ``eager`` are red and labels containing ``sdpa`` are
    blue (``eager`` is checked first). The darker shade marks labels that
    contain ``_compiled``. Any other label is yellow.
    """
    lowered = label.lower()
    compiled = "_compiled" in lowered
    if "eager" in lowered:
        return "#FF4444" if compiled else "#FF6B6B"  # red / light red
    if "sdpa" in lowered:
        return "#4A90E2" if compiled else "#7BB3F0"  # blue / light blue
    return "#FFD700"  # yellow for any other configuration


def _short_label(label, max_len=12):
    """Truncate ``label`` to ``max_len`` characters, appending '...' if cut."""
    return label[:max_len] + "..." if len(label) > max_len else label


def _draw_bar_panel(ax, labels, values, colors, ylabel, title):
    """Draw one black-themed bar chart (one bar per label) onto ``ax``."""
    positions = range(len(labels))
    ax.set_facecolor("#000000")
    ax.bar(positions, values, color=colors, width=1.0, edgecolor="white", linewidth=1)
    ax.set_xlabel("Model Configuration", color="white", fontsize=14)
    ax.set_ylabel(ylabel, color="white", fontsize=14)
    ax.set_title(title, color="white", fontsize=16, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(
        [_short_label(label) for label in labels],
        rotation=45, ha="right", color="white", fontsize=10,
    )
    ax.tick_params(colors="white")
    ax.grid(True, alpha=0.3, color="white")


def create_matplotlib_bar_charts():
    """Render side-by-side TTFT and TPOT bar charts as an embeddable HTML snippet.

    Fetches median, non-CUDA-time numbers from ``DATA.get_ttft_tpot_data``.
    It then draws TTFT (left panel) and TPOT (right panel) with one colored
    bar per configuration label. The PNG is returned inlined as a base64
    data URI inside a ``<div>``.

    Returns:
        str: HTML suitable for a ``gr.HTML`` component.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    # NOTE(review): assumes ``data`` maps 'label'/'ttft'/'tpot' to equally
    # long sequences; confirm against ModelBenchmarkData.
    labels = data["label"]
    colors = [_color_for_config(label) for label in labels]

    # Scope the dark style to this figure. plt.style.use() would permanently
    # rewrite the process-wide rcParams on every call.
    with plt.style.context("dark_background"):
        fig, (ax_ttft, ax_tpot) = plt.subplots(1, 2, figsize=(20, 12))
        try:
            fig.patch.set_facecolor("#000000")
            _draw_bar_panel(ax_ttft, labels, data["ttft"], colors,
                            "TTFT (seconds)", "Time To First Token by Configuration")
            _draw_bar_panel(ax_tpot, labels, data["tpot"], colors,
                            "TPOT (seconds)", "Time Per Output Token by Configuration")
            fig.tight_layout()  # keep the rotated tick labels from being clipped
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png", facecolor="#000000",
                        bbox_inches="tight", dpi=100)
        finally:
            # pyplot holds a reference to every open figure. Close it even if
            # drawing or saving failed, so the figure is not leaked.
            plt.close(fig)

    img_data = base64.b64encode(buffer.getvalue()).decode("ascii")
    # Near full-viewport container that centers and scales the image.
    return f"""
        <div style="width: 100%; height: 95vh; background: #000; display: flex; justify-content: center; align-items: center;">
            <img src="data:image/png;base64,{img_data}" style="max-width: 100%; max-height: 100%; object-fit: contain;" />
        </div>
        """
def refresh_plot():
    """Re-render the charts and return them with an updated sidebar text.

    Returns:
        tuple[str, str]: (chart HTML, sidebar markdown). The order matches the
        ``outputs=[plot, description]`` wired to the summary button.
    """
    chart_html = create_matplotlib_bar_charts()
    sidebar_markdown = (
        "**Transformer CI Dashboard**<br>-<br>"
        "**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>"
        "*This dashboard only tracks important models*<br>"
        "*(Data refreshed)*"
    )
    return chart_html, sidebar_markdown
# Dashboard layout: a sidebar (title, description, refresh button) beside a
# main column holding the rendered charts.
with gr.Blocks(
    title="Random Data Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Left-hand sidebar.
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                "**Transformer CI Dashboard**<br>-<br>"
                "**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>"
                "*This dashboard only tracks important models*",
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                "summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )

        # Main area. The chart HTML is rendered once while the UI is built
        # and replaced each time the summary button is clicked.
        with gr.Column(elem_classes=["main-content"]):
            plot = gr.HTML(
                create_matplotlib_bar_charts(),
                elem_classes=["plot-container"],
            )

    # Clicking "summary" re-renders the charts and swaps in the refreshed
    # sidebar text.
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])


if __name__ == "__main__":
    demo.launch()