Spaces:
Sleeping
Sleeping
Refactor
Browse files- .gitattributes +0 -35
- app.py +4 -133
- bar_plot.py +134 -0
- data.py +23 -21
.gitattributes
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
-
from math import e
|
| 2 |
-
import pandas as pd
|
| 3 |
import gradio as gr
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import matplotlib
|
| 6 |
-
import io
|
| 7 |
-
import base64
|
| 8 |
|
| 9 |
from data import ModelBenchmarkData
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Configure matplotlib for better performance
|
| 12 |
matplotlib.use('Agg')
|
|
@@ -16,12 +14,6 @@ plt.ioff()
|
|
| 16 |
DATA = ModelBenchmarkData("data.json")
|
| 17 |
|
| 18 |
|
| 19 |
-
def refresh_plot_data():
|
| 20 |
-
data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
|
| 21 |
-
print(data)
|
| 22 |
-
return pd.DataFrame(data)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
def load_css():
|
| 26 |
"""Load CSS styling."""
|
| 27 |
try:
|
|
@@ -30,132 +22,11 @@ def load_css():
|
|
| 30 |
except FileNotFoundError:
|
| 31 |
return "body { background: #000; color: #fff; }"
|
| 32 |
|
| 33 |
-
def create_matplotlib_bar_charts():
|
| 34 |
-
"""Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
|
| 35 |
-
data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
|
| 36 |
-
|
| 37 |
-
# Create figure with dark theme - larger for more screen space
|
| 38 |
-
plt.style.use('dark_background')
|
| 39 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 16))
|
| 40 |
-
fig.patch.set_facecolor('#000000')
|
| 41 |
-
|
| 42 |
-
# Prepare data
|
| 43 |
-
labels = data['label']
|
| 44 |
-
ttft_values = data['ttft']
|
| 45 |
-
tpot_values = data['tpot']
|
| 46 |
-
|
| 47 |
-
# Calculate error bars (standard deviation) for each configuration
|
| 48 |
-
import numpy as np
|
| 49 |
-
ttft_errors = []
|
| 50 |
-
tpot_errors = []
|
| 51 |
-
|
| 52 |
-
# Get raw data to calculate standard deviations
|
| 53 |
-
raw_data = DATA.data
|
| 54 |
-
for label in labels:
|
| 55 |
-
if label in raw_data:
|
| 56 |
-
# Calculate std dev for TTFT
|
| 57 |
-
ttft_raw = [d['wall_time'] for d in raw_data[label]['ttft']]
|
| 58 |
-
ttft_errors.append(float(np.std(ttft_raw)))
|
| 59 |
-
|
| 60 |
-
# Calculate std dev for TPOT
|
| 61 |
-
tpot_raw = [d['wall_time'] for d in raw_data[label]['tpot']]
|
| 62 |
-
tpot_errors.append(float(np.std(tpot_raw)))
|
| 63 |
-
else:
|
| 64 |
-
ttft_errors.append(0)
|
| 65 |
-
tpot_errors.append(0)
|
| 66 |
-
|
| 67 |
-
# Define color mapping based on configuration keywords
|
| 68 |
-
def get_color_for_config(label):
|
| 69 |
-
is_eager = 'eager' in label.lower()
|
| 70 |
-
is_sdpa = 'sdpa' in label.lower()
|
| 71 |
-
is_compiled = '_compiled' in label.lower()
|
| 72 |
-
if is_eager:
|
| 73 |
-
if is_compiled:
|
| 74 |
-
return '#FF4444' # Red for eager compiled
|
| 75 |
-
else:
|
| 76 |
-
return '#FF6B6B' # Light red for eager uncompiled
|
| 77 |
-
elif is_sdpa:
|
| 78 |
-
if is_compiled:
|
| 79 |
-
return '#4A90E2' # Blue for SDPA compiled
|
| 80 |
-
else:
|
| 81 |
-
return '#7BB3F0' # Light blue for SDPA uncompiled
|
| 82 |
-
else:
|
| 83 |
-
return '#FFD700' # Yellow for others
|
| 84 |
-
|
| 85 |
-
# Get colors for each bar
|
| 86 |
-
colors = [get_color_for_config(label) for label in labels]
|
| 87 |
-
|
| 88 |
-
# TTFT Plot (left)
|
| 89 |
-
ax1.set_facecolor('#000000')
|
| 90 |
-
_ = ax1.bar(range(len(labels)), ttft_values,
|
| 91 |
-
color=colors, width=1.0, edgecolor='white', linewidth=1,
|
| 92 |
-
label=[label[:15] + '...' if len(label) > 15 else label for label in labels])
|
| 93 |
-
|
| 94 |
-
# Add error bars for TTFT
|
| 95 |
-
ax1.errorbar(
|
| 96 |
-
range(len(labels)), ttft_values, yerr=ttft_errors,
|
| 97 |
-
fmt='none', ecolor='white', alpha=0.8,
|
| 98 |
-
elinewidth=1.5, capthick=1.5, capsize=4,
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
ax1.set_ylabel('TTFT (seconds)', color='white', fontsize=14)
|
| 102 |
-
ax1.set_title('Time to first token (lower is better)', color='white', fontsize=16, pad=20)
|
| 103 |
-
|
| 104 |
-
ax1.set_xticks([])
|
| 105 |
-
ax1.tick_params(colors='white')
|
| 106 |
-
ax1.grid(True, alpha=0.3, color='white')
|
| 107 |
-
|
| 108 |
-
# TPOT Plot (right)
|
| 109 |
-
ax2.set_facecolor('#000000')
|
| 110 |
-
_ = ax2.bar(range(len(labels)), tpot_values,
|
| 111 |
-
color=colors, width=1.0, edgecolor='white', linewidth=1)
|
| 112 |
-
|
| 113 |
-
# Add error bars for TPOT
|
| 114 |
-
ax2.errorbar(
|
| 115 |
-
range(len(labels)), tpot_values, yerr=tpot_errors,
|
| 116 |
-
fmt='none', ecolor='white', alpha=0.8,
|
| 117 |
-
elinewidth=1.5, capthick=1.5, capsize=4,
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
ax2.set_ylabel('TPOT (seconds)', color='white', fontsize=14)
|
| 121 |
-
ax2.set_title('Time per output token (lower is better)', color='white', fontsize=16, pad=20)
|
| 122 |
-
|
| 123 |
-
ax2.set_xticks([])
|
| 124 |
-
ax2.tick_params(colors='white')
|
| 125 |
-
ax2.grid(True, alpha=0.3, color='white')
|
| 126 |
-
|
| 127 |
-
# Add common legend with full text
|
| 128 |
-
legend_labels = labels # Use full labels without truncation
|
| 129 |
-
legend_handles = [plt.Rectangle((0,0),1,1, color=color, edgecolor='white') for color in colors]
|
| 130 |
-
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=1,
|
| 131 |
-
bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
|
| 132 |
-
labelcolor='white', fontsize=12)
|
| 133 |
-
|
| 134 |
-
# Tight layout with spacing between subplots and extra bottom space for legend
|
| 135 |
-
plt.tight_layout()
|
| 136 |
-
plt.subplots_adjust(wspace=0.3, bottom=0.075)
|
| 137 |
-
|
| 138 |
-
# Save plot to bytes with high DPI for crisp text
|
| 139 |
-
buffer = io.BytesIO()
|
| 140 |
-
plt.savefig(buffer, format='png', facecolor='#000000',
|
| 141 |
-
bbox_inches='tight', dpi=130)
|
| 142 |
-
buffer.seek(0)
|
| 143 |
-
|
| 144 |
-
# Convert to base64 for HTML embedding
|
| 145 |
-
img_data = base64.b64encode(buffer.getvalue()).decode()
|
| 146 |
-
plt.close(fig)
|
| 147 |
|
| 148 |
-
# Return HTML with embedded image - full height
|
| 149 |
-
html = f"""
|
| 150 |
-
<div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
|
| 151 |
-
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
|
| 152 |
-
</div>
|
| 153 |
-
"""
|
| 154 |
-
return html
|
| 155 |
|
| 156 |
def refresh_plot():
|
| 157 |
"""Generate new matplotlib charts and update description."""
|
| 158 |
-
return
|
| 159 |
|
| 160 |
# Create Gradio interface
|
| 161 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
@@ -169,7 +40,7 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
|
|
| 169 |
# Main plot area
|
| 170 |
with gr.Column(elem_classes=["main-content"]):
|
| 171 |
plot = gr.HTML(
|
| 172 |
-
|
| 173 |
elem_classes=["plot-container"],
|
| 174 |
)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import matplotlib
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from data import ModelBenchmarkData
|
| 6 |
+
from bar_plot import create_matplotlib_bar_plot
|
| 7 |
+
|
| 8 |
|
| 9 |
# Configure matplotlib for better performance
|
| 10 |
matplotlib.use('Agg')
|
|
|
|
| 14 |
DATA = ModelBenchmarkData("data.json")
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def load_css():
|
| 18 |
"""Load CSS styling."""
|
| 19 |
try:
|
|
|
|
| 22 |
except FileNotFoundError:
|
| 23 |
return "body { background: #000; color: #fff; }"
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def refresh_plot():
|
| 28 |
"""Generate new matplotlib charts and update description."""
|
| 29 |
+
return create_matplotlib_bar_plot(DATA.get_bar_plot_data()), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
|
| 30 |
|
| 31 |
# Create Gradio interface
|
| 32 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
|
|
| 40 |
# Main plot area
|
| 41 |
with gr.Column(elem_classes=["main-content"]):
|
| 42 |
plot = gr.HTML(
|
| 43 |
+
create_matplotlib_bar_plot(DATA.get_bar_plot_data()),
|
| 44 |
elem_classes=["plot-container"],
|
| 45 |
)
|
| 46 |
|
bar_plot.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import io
|
| 3 |
+
import numpy as np
|
| 4 |
+
import base64
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Color manipulation functions
|
| 8 |
+
def hex_to_rgb(hex_color):
|
| 9 |
+
hex_color = hex_color.lstrip('#')
|
| 10 |
+
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 11 |
+
return r, g, b
|
| 12 |
+
|
| 13 |
+
def increase_brightness(r, g, b, factor):
|
| 14 |
+
return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
|
| 15 |
+
|
| 16 |
+
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
|
| 17 |
+
gray = 0.299 * r + 0.587 * g + 0.114 * b
|
| 18 |
+
return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
|
| 19 |
+
|
| 20 |
+
def rgb_to_hex(r, g, b):
|
| 21 |
+
r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
|
| 22 |
+
return f"#{r:02x}{g:02x}{b:02x}"
|
| 23 |
+
|
| 24 |
+
# Color assignment function
|
| 25 |
+
def get_color_for_config(config):
|
| 26 |
+
|
| 27 |
+
# Determine the main hue for the attention implementation
|
| 28 |
+
attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
|
| 29 |
+
if attn_implementation == "eager":
|
| 30 |
+
main_hue = "#FF6B6B"
|
| 31 |
+
elif attn_implementation == "sdpa":
|
| 32 |
+
main_hue = {
|
| 33 |
+
None: "#4A90E2",
|
| 34 |
+
"math": "#408DDBFF",
|
| 35 |
+
"flash_attention": "#28767EFF",
|
| 36 |
+
"efficient_attention": "#605895FF",
|
| 37 |
+
"cudnn_attention": "#774AE2FF",
|
| 38 |
+
}[sdpa_backend]
|
| 39 |
+
elif attn_implementation == "flash_attention_2":
|
| 40 |
+
main_hue = "#FFD700"
|
| 41 |
+
else:
|
| 42 |
+
raise ValueError(f"Unknown attention implementation: {attn_implementation}")
|
| 43 |
+
|
| 44 |
+
# Apply color modifications for compilation and kernelization
|
| 45 |
+
r, g, b = hex_to_rgb(main_hue)
|
| 46 |
+
if config["compilation"]:
|
| 47 |
+
r, g, b = increase_brightness(r, g, b, 0.3)
|
| 48 |
+
if config["kernelize"]:
|
| 49 |
+
r, g, b = increase_saturation(r, g, b, 0.8)
|
| 50 |
+
|
| 51 |
+
# Return the color as a hex string
|
| 52 |
+
return rgb_to_hex(r, g, b)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
|
| 56 |
+
bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
|
| 57 |
+
errors = []
|
| 58 |
+
for i, (name, data) in enumerate(per_scenario_data.items()):
|
| 59 |
+
bar_kwargs["x"].append(i)
|
| 60 |
+
bar_kwargs["height"].append(np.median(data[key]))
|
| 61 |
+
bar_kwargs["color"].append(get_color_for_config(data["config"]))
|
| 62 |
+
bar_kwargs["label"].append(name)
|
| 63 |
+
errors.append(np.std(data[key]))
|
| 64 |
+
return bar_kwargs, errors
|
| 65 |
+
|
| 66 |
+
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
|
| 67 |
+
ax.set_facecolor('#000000')
|
| 68 |
+
# Draw bars
|
| 69 |
+
_ = ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
|
| 70 |
+
# Add error bars
|
| 71 |
+
ax.errorbar(
|
| 72 |
+
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 73 |
+
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
|
| 74 |
+
)
|
| 75 |
+
# Set labels and title
|
| 76 |
+
ax.set_ylabel(ylabel, color='white', fontsize=14)
|
| 77 |
+
ax.set_title(title, color='white', fontsize=16, pad=20)
|
| 78 |
+
# Set ticks and grid
|
| 79 |
+
ax.set_xticks([])
|
| 80 |
+
ax.tick_params(colors='white')
|
| 81 |
+
ax.grid(True, alpha=0.3, color='white')
|
| 82 |
+
# Truncate axis to better fit the bars
|
| 83 |
+
# new_ymin, new_ymax = 1e9, -1e9
|
| 84 |
+
# for h, e in zip(bar_kwargs["height"], errors):
|
| 85 |
+
# new_ymin = min(new_ymin, 0.98 * (h - e))
|
| 86 |
+
# new_ymax = max(new_ymax, 1.02 * (h + e))
|
| 87 |
+
# ymin, ymax = ax.get_ylim()
|
| 88 |
+
# ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def create_matplotlib_bar_plot(per_scenario_data: dict):
|
| 92 |
+
"""Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
|
| 93 |
+
|
| 94 |
+
# Create figure with dark theme - larger for more screen space
|
| 95 |
+
plt.style.use('dark_background')
|
| 96 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 12))
|
| 97 |
+
fig.patch.set_facecolor('#000000')
|
| 98 |
+
|
| 99 |
+
# TTFT Plot (left)
|
| 100 |
+
ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
|
| 101 |
+
draw_bar_plot(ax1, ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")
|
| 102 |
+
|
| 103 |
+
# TPOT Plot (right)
|
| 104 |
+
itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
|
| 105 |
+
draw_bar_plot(ax2, itl_bars, itl_errors, "Time per output token (lower is better)", "ITL (seconds)")
|
| 106 |
+
|
| 107 |
+
# Add common legend with full text
|
| 108 |
+
legend_labels = ttft_bars["label"] # Use full labels without truncation
|
| 109 |
+
legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
|
| 110 |
+
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=1,
|
| 111 |
+
bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
|
| 112 |
+
labelcolor='white', fontsize=12)
|
| 113 |
+
|
| 114 |
+
# Tight layout with spacing between subplots and extra bottom space for legend
|
| 115 |
+
# plt.subplots_adjust(wspace=0.3, bottom=0.075)
|
| 116 |
+
|
| 117 |
+
# Save plot to bytes with high DPI for crisp text
|
| 118 |
+
buffer = io.BytesIO()
|
| 119 |
+
plt.savefig(buffer, format='png', facecolor='#000000',
|
| 120 |
+
bbox_inches='tight', dpi=150)
|
| 121 |
+
buffer.seek(0)
|
| 122 |
+
|
| 123 |
+
# Convert to base64 for HTML embedding
|
| 124 |
+
img_data = base64.b64encode(buffer.getvalue()).decode()
|
| 125 |
+
plt.close(fig)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Return HTML with embedded image - full height
|
| 129 |
+
html = f"""
|
| 130 |
+
<div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">
|
| 131 |
+
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />
|
| 132 |
+
</div>
|
| 133 |
+
"""
|
| 134 |
+
return html
|
data.py
CHANGED
|
@@ -1,13 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
import numpy as np
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def estimate_from_measures(measures: list[float], estimator: str) -> float:
|
| 6 |
-
if estimator == "median":
|
| 7 |
-
return float(np.median(measures))
|
| 8 |
-
elif estimator == "mean":
|
| 9 |
-
return float(np.mean(measures))
|
| 10 |
-
raise ValueError(f"Invalid estimator: {estimator}")
|
| 11 |
|
| 12 |
|
| 13 |
class ModelBenchmarkData:
|
|
@@ -16,16 +9,25 @@ class ModelBenchmarkData:
|
|
| 16 |
with open(json_path, "r") as f:
|
| 17 |
self.data = json.load(f)
|
| 18 |
|
| 19 |
-
def
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import numpy as np
|
| 3 |
+
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class ModelBenchmarkData:
|
|
|
|
| 9 |
with open(json_path, "r") as f:
|
| 10 |
self.data = json.load(f)
|
| 11 |
|
| 12 |
+
def compute_e2e_latency(self, measures: dict) -> tuple[float, Optional[float]]:
|
| 13 |
+
return measures["e2e_latency"]
|
| 14 |
+
|
| 15 |
+
def compute_ttft(self, measures: dict) -> float:
|
| 16 |
+
return measures["t_tokens"][0] - measures["wall_time_start"]
|
| 17 |
+
|
| 18 |
+
def compute_itl(self, measures: dict) -> Optional[float]:
|
| 19 |
+
if len(measures["t_tokens"]) < 2:
|
| 20 |
+
return None
|
| 21 |
+
delta_t = measures["t_tokens"][-1] - measures["t_tokens"][0]
|
| 22 |
+
num_tokens = len(measures["t_tokens"]) - 1
|
| 23 |
+
return delta_t / num_tokens
|
| 24 |
+
|
| 25 |
+
def get_bar_plot_data(self) -> dict:
|
| 26 |
+
per_scenario_data = {}
|
| 27 |
+
for i, (cfg_name, data) in enumerate(self.data.items()):
|
| 28 |
+
per_scenario_data[cfg_name] = {
|
| 29 |
+
"ttft": [self.compute_ttft(d) for d in data["measures"]],
|
| 30 |
+
"itl": [self.compute_itl(d) for d in data["measures"]],
|
| 31 |
+
"config": data["metadata"]["config"],
|
| 32 |
+
}
|
| 33 |
+
return per_scenario_data
|