Spaces:
Sleeping
Sleeping
File size: 5,950 Bytes
55c8a69 dc41c89 55c8a69 e1f4b73 55c8a69 165f130 dc41c89 55c8a69 dc41c89 165f130 dc41c89 165f130 dc41c89 55c8a69 e1f4b73 55c8a69 dc41c89 55c8a69 165f130 55c8a69 165f130 dc41c89 55c8a69 dc41c89 165f130 dc41c89 e1f4b73 dc41c89 55c8a69 165f130 dc41c89 e1f4b73 dc41c89 e1f4b73 55c8a69 e1f4b73 55c8a69 e1f4b73 55c8a69 dc41c89 165f130 dc41c89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import matplotlib.pyplot as plt
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data
def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with its scenarios in a stable, readable order.

    Scenarios are sorted by attention implementation (flash attention 2, then
    SDPA, then eager, then any other implementation), then by SDPA backend,
    kernelization and compilation settings from each scenario's ``"config"``.

    Args:
        per_scenario_data: mapping of scenario name -> scenario data. Every value
            must hold a ``"config"`` dict with the keys ``attn_implementation``,
            ``sdpa_backend``, ``kernelize`` and ``compilation``.

    Returns:
        A new dict with the same (key, value) pairs, inserted in sorted order.
    """
    attn_priority = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}

    def none_first(value) -> tuple:
        # Config values may be None (e.g. no SDPA backend selected). Comparing None
        # with a str/bool raises TypeError, so None sorts ahead of concrete values.
        return (value is not None, value)

    def sorting_fn(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        attn_implementation = cfg["attn_implementation"]
        return (
            # Unknown implementations sort after the known ones instead of raising.
            attn_priority.get(attn_implementation, len(attn_priority)),
            # Keeps different unknown implementations grouped together; known ones
            # already have a unique priority, so this does not change their order.
            none_first(attn_implementation),
            none_first(cfg["sdpa_backend"]),
            none_first(cfg["kernelize"]),
            none_first(cfg["compilation"]),
        )

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sorting_fn)}
def infer_bar_label(config: dict) -> str:
    """Format a readable legend label for a benchmark configuration.

    Args:
        config: scenario configuration; the keys ``attn_implementation``,
            ``compilation`` and ``kernelize`` are read. The latter two are
            interpreted by truthiness.

    Returns:
        A label such as ``"SDPA, compiled, no kernels"``. Attention
        implementations without a pretty name are shown by their raw value.
    """
    attn_names = {
        "flash_attention_2": "Flash attention",
        "sdpa": "SDPA",
        "eager": "Eager",
    }
    attn_implementation = config["attn_implementation"]
    # Fall back to the raw name rather than raising KeyError on new implementations.
    attn_label = attn_names.get(attn_implementation, str(attn_implementation))
    # Named *_label to avoid shadowing the `compile` builtin.
    compile_label = "compiled" if config["compilation"] else "no compile"
    kernels_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernels_label}"
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list, list]:
    """Build ``ax.bar`` keyword arguments, error sizes and x-ticks for one metric.

    Bars are laid out device by device, one unit apart within a device and with
    an extra 1.5-unit gap between devices. Each bar's height is the median of
    ``scenario_data[key]`` and its error is the standard deviation.

    Args:
        per_device_data: mapping of device name -> device data object exposing
            ``get_bar_plot_data()``, which returns a mapping of scenario name ->
            scenario data (with a ``"config"`` dict and a sample list under ``key``).
        key: metric to plot (e.g. ``"ttft"`` or ``"itl"``).

    Returns:
        ``(bar_kwargs, error_bars, x_ticks)``: ``bar_kwargs`` holds parallel lists
        under ``"x"``, ``"height"``, ``"color"`` and ``"label"``; ``error_bars``
        is the list of standard deviations; ``x_ticks`` is a list of
        ``(x_center, device_name)`` pairs, one per device that has scenarios.
    """
    bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
    error_bars = []
    x_ticks = []
    current_x = 0
    for device_name, device_data in per_device_data.items():
        per_scenario_data = device_data.get_bar_plot_data()
        if not per_scenario_data:
            # No bars for this device: np.mean([]) would place its tick at NaN
            # (with a RuntimeWarning), so skip it entirely.
            continue
        per_scenario_data = reorder_data(per_scenario_data)
        device_xs = []
        for scenario_data in per_scenario_data.values():
            samples = scenario_data[key]
            config = scenario_data["config"]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            error_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1
        # Center the device label under its group of bars.
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5
    return bar_kwargs, error_bars, x_ticks
def _draw_latency_panels(fig, axs) -> None:
    """Draw the TTFT/ITL panels and shared legend, or a title reporting mismatched batch sizes."""
    per_device_data = load_data()
    batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
    if len(set(batch_sizes.values())) > 1:
        # Comparing devices benchmarked at different batch sizes would be misleading:
        # hide the panels and only report the mismatch. (Figure.suptitle has no
        # `pad` parameter, so none is passed.)
        for ax in axs:
            ax.set_visible(False)
        fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18)
        return

    # TTFT plot (top)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
    # ITL plot (bottom)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
    fig.tight_layout()

    # Common legend: every device repeats the same scenarios, so keep the first
    # occurrence of each (label, color) pair regardless of the number of devices.
    legend_entries = []
    for label, color in zip(ttft_bars["label"], ttft_bars["color"]):
        if (label, color) not in legend_entries:
            legend_entries.append((label, color))
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for _, color in legend_entries]
    legend_labels = [label for label, _ in legend_entries]
    # Legend centered below the two panels
    fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
               bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
               labelcolor='white', fontsize=14)


def _figure_to_html(fig) -> str:
    """Render ``fig`` to PNG (150 DPI) and embed it as base64 in a full-page HTML ``<div>``."""
    buffer = io.BytesIO()
    # High DPI for crisp text; bbox_inches='tight' keeps the out-of-axes legend.
    fig.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=150)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    return f"""
<div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
</div>
"""


def create_matplotlib_bar_plot() -> str:
    """Create stacked matplotlib bar charts for TTFT and ITL data and return them as HTML.

    Returns:
        An HTML snippet embedding the rendered PNG as a base64 data URI. If the
        devices were benchmarked with different main batch sizes, the rendered
        figure only shows a title reporting the mismatch.
    """
    # Scope the dark theme to this figure instead of mutating global rcParams.
    with plt.style.context('dark_background'):
        # Large figure so the chart can fill the screen
        fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
        fig.patch.set_facecolor('#000000')
        try:
            _draw_latency_panels(fig, axs)
            return _figure_to_html(fig)
        finally:
            # Always release the figure, even if loading or drawing fails.
            plt.close(fig)
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: "str | None", ylabel: str, xticks: list[tuple[float, str]]):
    """Draw one dark-themed bar panel with error bars and a y-axis zoomed onto the bars.

    Args:
        ax: axes to draw into.
        bar_kwargs: keyword arguments for ``ax.bar`` with parallel lists under
            ``"x"``, ``"height"``, ``"color"`` and ``"label"``.
        errors: symmetric error size for each bar, parallel to ``bar_kwargs["height"]``.
        title: panel title; ``None`` leaves the title empty.
        ylabel: y-axis label.
        xticks: ``(x_position, label)`` pairs for the x-axis ticks.
    """
    ax.set_facecolor('#000000')
    ax.grid(True, alpha=0.2, color='white', zorder=0)
    # Bars above the grid, error bars above the bars
    ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
        fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
    )
    # Labels and title
    ax.set_ylabel(ylabel, color='white', fontsize=16)
    ax.set_title(title, color='white', fontsize=18, pad=20)
    # Ticks
    ax.tick_params(colors='white', labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)

    # Truncate the y-axis to the bars +/- errors, with a 2% margin on each side.
    # Without bars there is nothing to fit (the old sentinels inverted the axis).
    heights = bar_kwargs["height"]
    if len(heights) == 0:
        return
    lowest = min(h - e for h, e in zip(heights, errors))
    highest = max(h + e for h, e in zip(heights, errors))
    # Widen outward for any sign (0.98*x would move a negative bound inward).
    new_ymin = lowest - 0.02 * abs(lowest)
    new_ymax = highest + 0.02 * abs(highest)
    ymin, ymax = ax.get_ylim()
    lower, upper = max(ymin, new_ymin), min(ymax, new_ymax)
    # Identical limits (e.g. all-zero bars) would make matplotlib warn; keep autoscale then.
    if lower < upper:
        ax.set_ylim(lower, upper)
|