Spaces:
Sleeping
Sleeping
matplotlib
Browse files
app.py
CHANGED
|
@@ -1,9 +1,16 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import gradio as gr
|
| 3 |
-
import
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from data import ModelBenchmarkData
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
DATA = ModelBenchmarkData("data.json")
|
| 9 |
|
|
@@ -22,9 +29,97 @@ def load_css():
|
|
| 22 |
except FileNotFoundError:
|
| 23 |
return "body { background: #000; color: #fff; }"
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def refresh_plot():
|
| 26 |
-
"""Generate new
|
| 27 |
-
return
|
| 28 |
|
| 29 |
# Create Gradio interface
|
| 30 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
@@ -37,14 +132,9 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
|
|
| 37 |
|
| 38 |
# Main plot area
|
| 39 |
with gr.Column(elem_classes=["main-content"]):
|
| 40 |
-
plot = gr.
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
tooltip="all",
|
| 44 |
-
height="100vh",
|
| 45 |
-
container=False,
|
| 46 |
-
show_fullscreen_button=True,
|
| 47 |
-
elem_classes=["plot-container"]
|
| 48 |
)
|
| 49 |
|
| 50 |
# Button click handler
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import gradio as gr
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import matplotlib
|
| 5 |
+
import io
|
| 6 |
+
import base64
|
| 7 |
|
| 8 |
from data import ModelBenchmarkData
|
| 9 |
|
| 10 |
+
# Configure matplotlib for better performance
|
| 11 |
+
matplotlib.use('Agg')
|
| 12 |
+
plt.ioff()
|
| 13 |
+
|
| 14 |
|
| 15 |
DATA = ModelBenchmarkData("data.json")
|
| 16 |
|
|
|
|
| 29 |
except FileNotFoundError:
|
| 30 |
return "body { background: #000; color: #fff; }"
|
| 31 |
|
| 32 |
+
def create_matplotlib_bar_charts():
|
| 33 |
+
"""Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
|
| 34 |
+
data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
|
| 35 |
+
|
| 36 |
+
# Create figure with dark theme - wider for side-by-side plots
|
| 37 |
+
plt.style.use('dark_background')
|
| 38 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 12))
|
| 39 |
+
fig.patch.set_facecolor('#000000')
|
| 40 |
+
|
| 41 |
+
# Prepare data
|
| 42 |
+
labels = data['label']
|
| 43 |
+
ttft_values = data['ttft']
|
| 44 |
+
tpot_values = data['tpot']
|
| 45 |
+
|
| 46 |
+
# Define color mapping based on configuration keywords
|
| 47 |
+
def get_color_for_config(label):
|
| 48 |
+
is_eager = 'eager' in label.lower()
|
| 49 |
+
is_sdpa = 'sdpa' in label.lower()
|
| 50 |
+
is_compiled = '_compiled' in label.lower()
|
| 51 |
+
if is_eager:
|
| 52 |
+
if is_compiled:
|
| 53 |
+
return '#FF4444' # Red for eager compiled
|
| 54 |
+
else:
|
| 55 |
+
return '#FF6B6B' # Light red for eager uncompiled
|
| 56 |
+
elif is_sdpa:
|
| 57 |
+
if is_compiled:
|
| 58 |
+
return '#4A90E2' # Blue for SDPA compiled
|
| 59 |
+
else:
|
| 60 |
+
return '#7BB3F0' # Light blue for SDPA uncompiled
|
| 61 |
+
else:
|
| 62 |
+
return '#FFD700' # Yellow for others
|
| 63 |
+
|
| 64 |
+
# Get colors for each bar
|
| 65 |
+
colors = [get_color_for_config(label) for label in labels]
|
| 66 |
+
|
| 67 |
+
# TTFT Plot (left)
|
| 68 |
+
ax1.set_facecolor('#000000')
|
| 69 |
+
bars1 = ax1.bar(range(len(labels)), ttft_values,
|
| 70 |
+
color=colors, width=1.0, edgecolor='white', linewidth=1)
|
| 71 |
+
|
| 72 |
+
ax1.set_xlabel('Model Configuration', color='white', fontsize=14)
|
| 73 |
+
ax1.set_ylabel('TTFT (seconds)', color='white', fontsize=14)
|
| 74 |
+
ax1.set_title('Time To First Token by Configuration', color='white', fontsize=16, pad=20)
|
| 75 |
+
|
| 76 |
+
ax1.set_xticks(range(len(labels)))
|
| 77 |
+
ax1.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels],
|
| 78 |
+
rotation=45, ha='right', color='white', fontsize=10)
|
| 79 |
+
|
| 80 |
+
ax1.tick_params(colors='white')
|
| 81 |
+
ax1.grid(True, alpha=0.3, color='white')
|
| 82 |
+
|
| 83 |
+
# TPOT Plot (right)
|
| 84 |
+
ax2.set_facecolor('#000000')
|
| 85 |
+
bars2 = ax2.bar(range(len(labels)), tpot_values,
|
| 86 |
+
color=colors, width=1.0, edgecolor='white', linewidth=1)
|
| 87 |
+
|
| 88 |
+
ax2.set_xlabel('Model Configuration', color='white', fontsize=14)
|
| 89 |
+
ax2.set_ylabel('TPOT (seconds)', color='white', fontsize=14)
|
| 90 |
+
ax2.set_title('Time Per Output Token by Configuration', color='white', fontsize=16, pad=20)
|
| 91 |
+
|
| 92 |
+
ax2.set_xticks(range(len(labels)))
|
| 93 |
+
ax2.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels],
|
| 94 |
+
rotation=45, ha='right', color='white', fontsize=10)
|
| 95 |
+
|
| 96 |
+
ax2.tick_params(colors='white')
|
| 97 |
+
ax2.grid(True, alpha=0.3, color='white')
|
| 98 |
+
|
| 99 |
+
# Tight layout to prevent label cutoff
|
| 100 |
+
plt.tight_layout()
|
| 101 |
+
|
| 102 |
+
# Save plot to bytes
|
| 103 |
+
buffer = io.BytesIO()
|
| 104 |
+
plt.savefig(buffer, format='png', facecolor='#000000',
|
| 105 |
+
bbox_inches='tight', dpi=100)
|
| 106 |
+
buffer.seek(0)
|
| 107 |
+
|
| 108 |
+
# Convert to base64 for HTML embedding
|
| 109 |
+
img_data = base64.b64encode(buffer.getvalue()).decode()
|
| 110 |
+
plt.close(fig)
|
| 111 |
+
|
| 112 |
+
# Return HTML with embedded image - almost full height
|
| 113 |
+
html = f"""
|
| 114 |
+
<div style="width: 100%; height: 95vh; background: #000; display: flex; justify-content: center; align-items: center;">
|
| 115 |
+
<img src="data:image/png;base64,{img_data}" style="max-width: 100%; max-height: 100%; object-fit: contain;" />
|
| 116 |
+
</div>
|
| 117 |
+
"""
|
| 118 |
+
return html
|
| 119 |
+
|
| 120 |
def refresh_plot():
|
| 121 |
+
"""Generate new matplotlib charts and update description."""
|
| 122 |
+
return create_matplotlib_bar_charts(), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
|
| 123 |
|
| 124 |
# Create Gradio interface
|
| 125 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
|
|
| 132 |
|
| 133 |
# Main plot area
|
| 134 |
with gr.Column(elem_classes=["main-content"]):
|
| 135 |
+
plot = gr.HTML(
|
| 136 |
+
create_matplotlib_bar_charts(),
|
| 137 |
+
elem_classes=["plot-container"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
)
|
| 139 |
|
| 140 |
# Button click handler
|
data.py
CHANGED
|
@@ -17,12 +17,15 @@ class ModelBenchmarkData:
|
|
| 17 |
self.data = json.load(f)
|
| 18 |
|
| 19 |
def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
|
| 20 |
-
aggregated_data = {"ttft": [], "tpot": [], "label": []}
|
| 21 |
time_key = "cuda_time" if use_cuda_time else "wall_time"
|
|
|
|
| 22 |
for cfg_name, data in self.data.items():
|
| 23 |
x_measures = [d[time_key] for d in data["ttft"]]
|
| 24 |
y_measures = [d[time_key] for d in data["tpot"]]
|
| 25 |
aggregated_data["ttft"].append(estimate_from_measures(x_measures, estimator))
|
| 26 |
aggregated_data["tpot"].append(estimate_from_measures(y_measures, estimator))
|
| 27 |
aggregated_data["label"].append(cfg_name)
|
|
|
|
|
|
|
| 28 |
return aggregated_data
|
|
|
|
| 17 |
self.data = json.load(f)
|
| 18 |
|
| 19 |
def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
|
| 20 |
+
aggregated_data = {"ttft": [], "tpot": [], "label": [], "position": []}
|
| 21 |
time_key = "cuda_time" if use_cuda_time else "wall_time"
|
| 22 |
+
position = 0
|
| 23 |
for cfg_name, data in self.data.items():
|
| 24 |
x_measures = [d[time_key] for d in data["ttft"]]
|
| 25 |
y_measures = [d[time_key] for d in data["tpot"]]
|
| 26 |
aggregated_data["ttft"].append(estimate_from_measures(x_measures, estimator))
|
| 27 |
aggregated_data["tpot"].append(estimate_from_measures(y_measures, estimator))
|
| 28 |
aggregated_data["label"].append(cfg_name)
|
| 29 |
+
aggregated_data["position"].append(position)
|
| 30 |
+
position += 1
|
| 31 |
return aggregated_data
|