ffasr / analytics.py
whojavumusic's picture
changes and cleanup in text
af0dc0a
Raw
History Blame Contribute Delete
44.3 kB
"""
Interactive Plotly charts for the Analysis tab (legend toggles traces, zoom, pan, hover).
Radar = compare models on normalized axes. Middle chart = WER by model (models on the x axis, one trace per scenario).
Bar chart = grouped WER by scenario for top models.
"""
from __future__ import annotations
import threading
from typing import Sequence
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.colors import qualitative
from metrics_config import (
HEATMAP_SCENARIO_KEYS,
LIVE_SCENARIO_KEYS,
SCENARIO_METRICS,
heatmap_label_for_key,
metric_by_key,
)
# Realistic-recording SNR splits. `snr_avg_wer` averages whichever of these a
# row provides, and `compute_pareto_layers` uses that average as its WER axis.
SNR_WER_KEYS: tuple[str, ...] = (
    "wer_realistic_high_snr",  # high SNR
    "wer_realistic_mid_snr",  # mid SNR
    "wer_realistic_low_snr",  # low SNR
)
def _wer_pct(v) -> float | None:
"""Fractional WER (0–1) → percent for chart display."""
if v is None or (isinstance(v, float) and not np.isfinite(v)):
return None
try:
f = float(v)
except (TypeError, ValueError):
return None
if not np.isfinite(f):
return None
return f * 100.0
def _pareto_efficient_frontier_wer_rtf(
wers: Sequence[float], rtfs: Sequence[float]
) -> tuple[list[float], list[float]]:
"""Layer-1 envelope: X = WER (lower better), Y = RTF (higher better)."""
pts = [
(float(w), float(r))
for w, r in zip(wers, rtfs)
if np.isfinite(w) and np.isfinite(r)
]
if not pts:
return [], []
frontier: list[tuple[float, float]] = []
for i, (wi, ri) in enumerate(pts):
dominated = False
for j, (wj, rj) in enumerate(pts):
if i == j:
continue
if wj <= wi and rj >= ri and (wj < wi or rj > ri):
dominated = True
break
if not dominated:
frontier.append((wi, ri))
frontier.sort(key=lambda p: p[0])
return [p[0] for p in frontier], [p[1] for p in frontier]
# Consistent height for Gradio Plot
_FIG_HEIGHT = 460
_TEMPLATE = "plotly_white"
# Weight-file size lookup for the "Memory" radar axis. Populated lazily on first render;
# falls back to NaN if the hub call fails (e.g. gated repo without token).
_model_size_cache: dict[str, float] = {}
_model_size_lock = threading.Lock()
_WEIGHT_EXTS = (".safetensors", ".bin", ".pt", ".pth", ".gguf", ".ckpt", ".msgpack", ".h5")
def _cached_model_size_mb(model_id: str) -> float:
with _model_size_lock:
if model_id in _model_size_cache:
return _model_size_cache[model_id]
total = 0
try:
from huggingface_hub import HfApi
info = HfApi().model_info(model_id, files_metadata=True)
for s in info.siblings or []:
name = getattr(s, "rfilename", "") or ""
size = getattr(s, "size", None) or 0
if name.lower().endswith(_WEIGHT_EXTS):
total += int(size)
except Exception:
total = 0
mb = float(total) / (1024.0 * 1024.0) if total > 0 else float("nan")
with _model_size_lock:
_model_size_cache[model_id] = mb
return mb
def _coerce_wer(v) -> float | None:
"""Non-negative finite WER, or None if missing / invalid."""
if v is None or v == "":
return None
try:
x = float(v)
return x if np.isfinite(x) and x >= 0 else None
except (TypeError, ValueError):
return None
def _coerce_rtf(v) -> float | None:
"""Non-negative finite RTF (audio sec / inference sec), or None if missing."""
if v is None or v == "":
return None
try:
x = float(v)
return x if np.isfinite(x) and x >= 0 else None
except (TypeError, ValueError):
return None
def snr_avg_wer(row: dict) -> float | None:
    """Average WER across the realistic high / mid / low SNR splits.

    Splits that are missing or invalid for *row* are skipped. Returns ``None``
    when none of them is usable.
    """
    usable = [
        wer
        for wer in (_coerce_wer(row.get(key)) for key in SNR_WER_KEYS)
        if wer is not None
    ]
    return float(sum(usable) / len(usable)) if usable else None
def compute_pareto_layers(rows: list[dict]) -> dict[str, int | None]:
    """
    Assign each model a Pareto layer over (snr_avg_wer ↓, eval_rtf ↑).

    Layer 1 holds the models that no other model beats on both axes. Those are
    removed and the procedure repeats for layers 2, 3, ... Models whose SNR WER
    or RTF is missing map to ``None``; rows without a ``model_id`` are ignored.
    """
    layers: dict[str, int | None] = {}
    pending: list[tuple[str, float, float]] = []
    for record in rows:
        model_id = (record.get("model_id") or "").strip()
        if not model_id:
            continue
        wer = snr_avg_wer(record)
        rtf = _coerce_rtf(record.get("eval_rtf"))
        if wer is None or rtf is None:
            layers[model_id] = None
        else:
            pending.append((model_id, wer, rtf))

    def _is_dominated(idx: int, pool: list[tuple[str, float, float]]) -> bool:
        # Dominated = some other point is no worse on both axes and strictly
        # better on at least one.
        _, wer_i, rtf_i = pool[idx]
        return any(
            wer_j <= wer_i and rtf_j >= rtf_i and (wer_j < wer_i or rtf_j > rtf_i)
            for j, (_, wer_j, rtf_j) in enumerate(pool)
            if j != idx
        )

    depth = 1
    while pending:
        front = [
            entry[0]
            for idx, entry in enumerate(pending)
            if not _is_dominated(idx, pending)
        ]
        for model_id in front:
            layers[model_id] = depth
        peeled = set(front)
        pending = [entry for entry in pending if entry[0] not in peeled]
        depth += 1
    return layers
def _avg_wer_for_row(row: dict) -> float:
    """Average of the row's valid live-scenario WERs, or ``inf`` if it has none.

    Returning ``inf`` makes rows without results sort after every scored row.
    """
    scored = [
        wer
        for wer in map(_coerce_wer, (row.get(key) for key in LIVE_SCENARIO_KEYS))
        if wer is not None
    ]
    if scored:
        return sum(scored) / len(scored)
    return float("inf")
def sort_leaderboard_rows_inplace(rows: list[dict]) -> None:
    """Order *rows* in place by average live-scenario WER, lowest (best) first.

    Ties, including rows with no WER at all (which score ``inf``), are broken
    by ``model_id`` so the ordering is deterministic.
    """

    def _rank(row: dict) -> tuple[float, str]:
        return _avg_wer_for_row(row), (row.get("model_id") or "")

    rows.sort(key=_rank)
def frontier_label(layer: int | None) -> str:
    """Display text for a Pareto layer: its number, or ``"NA"`` when unranked."""
    if layer is None:
        return "NA"
    return str(layer)
def _log_normalize(values: np.ndarray, *, inverse: bool, strength_floor: float = 0.1) -> np.ndarray:
"""
Log10-normalize an array across all models to a "strength" score (default display floor 0.1).
* `inverse=True` -> lower raw value is better (WER, memory). Strength = 1 - norm.
* `inverse=False` -> higher raw value is better (RTF). Strength = norm.
Returns NaN for entries whose raw value is missing or non-positive.
"""
v = np.asarray(values, dtype=float)
out = np.full_like(v, np.nan, dtype=float)
mask = np.isfinite(v) & (v > 0)
if not mask.any():
return out
logv = np.log10(np.maximum(v[mask], 1e-12))
lo = float(np.min(logv))
hi = float(np.max(logv))
if hi - lo < 1e-9:
out[mask] = 0.5
else:
out[mask] = (logv - lo) / (hi - lo)
if inverse:
out[mask] = 1.0 - out[mask]
fin = np.isfinite(out)
out[fin] = np.clip(np.maximum(out[fin], strength_floor), strength_floor, 1.0)
return out
# Axis spec: (strength_label, primary col, fallback col, inverse, raw_display_name, raw_unit)
# strength_label: chart axis label; always phrased so "farther = better".
# inverse : True if *lower* raw value is better (e.g. WER, memory); False if higher is better (RTF).
# raw_display_name: human-readable name of the underlying metric, shown in the hover tooltip.
# raw_unit : optional unit suffix for the raw value in the tooltip.
#
# Live far-field scenario axes + Speed (RTF) + Compactness (parameters / Hub size).
_RADAR_AXES: tuple[tuple[str, str, str | None, bool, str, str], ...] = (
("Near Field Speech", "wer_anechoic_speech", None, True, "Near field speech WER", ""),
("Lab Measured", "wer_lab_measured", None, True, "Lab measured WER", ""),
("Lab Simulated", "wer_lab_simulated", None, True, "Lab simulated WER", ""),
("High SNR", "wer_realistic_high_snr", None, True, "High SNR WER", ""),
("Mid SNR", "wer_realistic_mid_snr", None, True, "Mid SNR WER", ""),
("Low SNR", "wer_realistic_low_snr", None, True, "Low SNR WER", ""),
("Moving Sources", "wer_moving_sources", None, True, "Moving Sources WER", ""),
("Speed", "eval_rtf", None, False, "RTFx", "× realtime"),
# Compactness: smaller model -> more compact -> farther from centre.
# `inverse=True` makes `_log_normalize` emit `1 - norm`, so the smallest
# observed model maps to strength = 1.0 (radar edge) and the largest
# maps to 0.0 (centre).
#
# Primary signal: parameter count (millions), which is recorded for every
# backend by `attach_params` and is much more reliable than the Hub
# weight-file size lookup (which can return NaN for gated/private repos
# or repos that ship weights via LFS pointers without size metadata).
# `model_size_mb` is the fallback for legacy rows that pre-date param
# tracking.
("Compactness", "num_params_m", "model_size_mb", True, "Parameters", " M"),
)
# Hover-label override when a radar axis falls back from `primary` to
# `fallback`. Without this the tooltip would still claim the value is in the
# primary metric (e.g. "Parameters: 1500.0000 M") when in fact the rendered
# value came from the fallback column (`Model size: 1500.0000 MB`).
_RADAR_FALLBACK_DISPLAY: dict[str, tuple[str, str]] = {
"num_params_m": ("Model size", " MB"),
}
def _empty_fig(message: str) -> go.Figure:
    """Return a 400 px placeholder figure with *message* centred and both axes hidden."""
    centred_note = dict(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=14),
    )
    placeholder = go.Figure()
    placeholder.add_annotation(**centred_note)
    placeholder.update_layout(
        template=_TEMPLATE,
        height=400,
        margin=dict(l=40, r=40, t=40, b=40),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
    )
    return placeholder
def _plotly_legend_config() -> dict:
"""Legend click toggles traces; double-click isolates."""
return dict(
legend=dict(
orientation="v",
yanchor="top",
y=1,
xanchor="left",
x=1.02,
font=dict(size=10),
),
hovermode="closest",
)
def _raw_to_analytics_df(raw: list[dict]) -> pd.DataFrame:
    """
    Build the analytics DataFrame from raw leaderboard rows.

    Each row first goes through ``normalize_legacy_csv_row``. Scenario WER
    columns and the timing / size columns (``eval_wall_time_s``, ``eval_rtf``,
    ``eval_audio_seconds``, ``num_params``) are coerced to float; missing or
    unparsable values become NaN. Derived columns:

    * ``num_params_m``: ``num_params`` in millions.
    * ``avg_wer``: mean of the live-scenario WERs present for that row (reference).
    * ``pareto_layer``: layer from ``compute_pareto_layers`` (None when unranked).
    * ``snr_avg_wer``: mean WER over the realistic SNR splits.
    * ``model_size_mb``: weight-file size fetched from the Hub (cached; NaN on failure).

    Returns an empty DataFrame when *raw* is empty or the rows carry no
    ``model_id`` column.
    """
    from init import normalize_legacy_csv_row

    if not raw:
        return pd.DataFrame()
    rows = [normalize_legacy_csv_row(dict(record)) for record in raw]
    frame = pd.DataFrame(rows)
    if "model_id" not in frame.columns:
        return pd.DataFrame()

    numeric_columns = [metric.key for metric in SCENARIO_METRICS]
    numeric_columns += ["eval_wall_time_s", "eval_rtf", "eval_audio_seconds", "num_params"]
    for column in numeric_columns:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
        else:
            frame[column] = np.nan

    # Millions of parameters, for compact display.
    frame["num_params_m"] = frame["num_params"] / 1e6

    live = [c for c in LIVE_SCENARIO_KEYS if c in frame.columns]
    frame["avg_wer"] = frame[live].mean(axis=1, skipna=True) if live else np.nan

    # `rows` is in the same order as `frame`, so list assignment aligns row-wise.
    pareto = compute_pareto_layers(rows)
    frame["pareto_layer"] = [pareto.get((row.get("model_id") or "").strip()) for row in rows]
    frame["snr_avg_wer"] = [snr_avg_wer(row) for row in rows]
    frame["model_size_mb"] = frame["model_id"].apply(_cached_model_size_mb)
    return frame
def _sort_df_leaderboard_order(df: pd.DataFrame) -> pd.DataFrame:
"""Average WER ascending (lower is better)."""
if df.empty:
return df
d = df.copy()
if "avg_wer" in d.columns:
return d.sort_values("avg_wer", ascending=True, na_position="last")
return d
def available_metric_keys(df: pd.DataFrame) -> list[str]:
    """Keys from SCENARIO_METRICS (in that order) whose column in *df* has any non-NaN value."""
    columns = set(df.columns)
    return [
        metric.key
        for metric in SCENARIO_METRICS
        if metric.key in columns and df[metric.key].notna().any()
    ]
# Brand-aligned colors per HF org / company. Lower-cased org prefix → hex.
# Unknown orgs fall back to a deterministic palette so that the same company
# keeps the same color across the Intelligence / Speed charts.
_COMPANY_COLORS: dict[str, str] = {
"nvidia": "#76B900", # NVIDIA green
"openai": "#111111", # OpenAI black
"openai-community": "#111111",
"qwen": "#615CED", # Qwen purple
"ibm-granite": "#0F62FE", # IBM blue
"ibm": "#0F62FE",
"coherelabs": "#39594D", # Cohere muted green
"cohere": "#39594D",
"efficient-speech": "#F2994A",
"usefulsensors": "#22A7F0",
"google": "#4285F4", # Google blue
"google-t5": "#4285F4",
"meta-llama": "#1877F2", # Meta blue
"facebook": "#1877F2",
"microsoft": "#00A4EF",
"deepseek-ai": "#1D6FE0",
"deepseek": "#1D6FE0",
"moonshotai": "#111111",
"mistralai": "#FF6F00",
"anthropic": "#CC9B73",
"x-ai": "#000000",
"xai": "#000000",
"zhipuai": "#1F8AFF",
"zai-org": "#1F8AFF",
}
_FALLBACK_PALETTE: tuple[str, ...] = (
"#1F77B4", "#FF7F0E", "#2CA02C", "#D62728", "#9467BD",
"#8C564B", "#E377C2", "#7F7F7F", "#BCBD22", "#17BECF",
)
def _company_key(model_id: str) -> str:
"""Normalized company key from a HF model id (`org/name` → `org`)."""
mid = (model_id or "").strip()
if "/" in mid:
return mid.split("/", 1)[0].lower()
return mid.lower()
def _company_color_map(model_ids: Sequence[str]) -> dict[str, str]:
"""Stable model_id → hex color mapping; same company shares a color."""
out: dict[str, str] = {}
fallback_assign: dict[str, str] = {}
unknown_companies: list[str] = []
for mid in model_ids:
comp = _company_key(mid)
if comp in _COMPANY_COLORS:
out[mid] = _COMPANY_COLORS[comp]
else:
if comp not in fallback_assign:
unknown_companies.append(comp)
out[mid] = "" # filled in below
# Deterministic palette assignment for unknown orgs (sorted alphabetically).
for i, comp in enumerate(sorted(set(unknown_companies))):
fallback_assign[comp] = _FALLBACK_PALETTE[i % len(_FALLBACK_PALETTE)]
for mid in model_ids:
if not out[mid]:
out[mid] = fallback_assign[_company_key(mid)]
return out
def _short_model_label(model_id: str, max_len: int = 26) -> str:
name = (model_id or "").split("/", 1)[-1]
return name if len(name) <= max_len else name[: max_len - 1] + "…"
def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
h = hex_color.lstrip("#")
return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
def _mix(c: int, target: int, t: float) -> int:
return int(round(c + (target - c) * max(0.0, min(1.0, t))))
def _lighten(hex_color: str, t: float) -> str:
"""t in [0,1] -> mix toward white (t=0 returns the original color)."""
r, g, b = _hex_to_rgb(hex_color)
return "#{:02X}{:02X}{:02X}".format(_mix(r, 255, t), _mix(g, 255, t), _mix(b, 255, t))
def _darken(hex_color: str, t: float) -> str:
"""t in [0,1] -> mix toward black (t=0 returns the original color)."""
r, g, b = _hex_to_rgb(hex_color)
return "#{:02X}{:02X}{:02X}".format(_mix(r, 0, t), _mix(g, 0, t), _mix(b, 0, t))
def _to_rgba(hex_color: str, alpha: float) -> str:
r, g, b = _hex_to_rgb(hex_color)
return f"rgba({r},{g},{b},{alpha:.3f})"
def _company_shade_assignments(
    model_ids: Sequence[str],
    color_map: dict[str, str],
) -> tuple[list[str], list[str], list[float]]:
    """
    Per-bar fill, outline and label opacity for a company-colored bar chart.

    A company with a single bar keeps its solid brand color. When a company has
    several bars, its k-th bar (k = 0, 1, ...) is lightened by min(0.45, 0.12·k)
    and drawn with alpha max(0.65, 1 - 0.08·k). The first bar keeps the base
    color, and neighbouring same-company bars stay distinguishable.

    Returns three lists parallel to *model_ids*:
      * fill colors as ``rgba(...)`` strings (alpha baked in),
      * outline colors: the brand color darkened by 35 %,
      * in-bar text opacity (1.0 for single bars, never below 0.9 otherwise).
    """
    per_company: dict[str, int] = {}
    for mid in model_ids:
        key = _company_key(mid)
        per_company[key] = per_company.get(key, 0) + 1

    occurrence: dict[str, int] = {}
    fills: list[str] = []
    borders: list[str] = []
    text_alphas: list[float] = []
    for mid in model_ids:
        base = color_map[mid]
        key = _company_key(mid)
        rank = occurrence.get(key, 0)
        occurrence[key] = rank + 1
        borders.append(_darken(base, 0.35))
        if per_company[key] == 1:
            fills.append(_to_rgba(base, 1.0))
            text_alphas.append(1.0)
            continue
        shade = _lighten(base, min(0.45, 0.12 * rank))
        fills.append(_to_rgba(shade, max(0.65, 1.0 - 0.08 * rank)))
        text_alphas.append(max(0.9, 1.0 - 0.05 * rank))
    return fills, borders, text_alphas
def _plot_company_bars(
labels: Sequence[str],
values: Sequence[float],
model_ids: Sequence[str],
*,
title: str,
subtitle: str,
value_fmt: str,
hover_value_fmt: str,
hover_metric_label: str,
) -> go.Figure:
"""
Shared renderer for the Intelligence / Speed vertical bar charts.
- Bars are colored by HF org via `_company_color_map`.
- Same-company bars in the same chart get progressively lighter fills +
slightly reduced alpha so they remain distinguishable.
- Each bar has a thin darker outline ("corner line") tracing the rounded
corners so it stands out from the white background.
- Values are drawn inside the bar (white) with rounded corners.
- Hover shows the full model id and the metric value.
"""
color_map = _company_color_map(list(model_ids))
fills, borders, text_alphas = _company_shade_assignments(
list(model_ids), color_map
)
text_in = [value_fmt.format(v) for v in values]
text_colors = [f"rgba(255,255,255,{a:.3f})" for a in text_alphas]
# Use integer x-positions so duplicate display labels (e.g. two models that
# share the same truncated short name) don't collapse into a single stacked
# bar. Tick labels are mapped back onto these positions below.
x_positions = list(range(len(labels)))
fig = go.Figure(
go.Bar(
x=x_positions,
y=values,
customdata=list(model_ids),
marker=dict(
color=fills,
line=dict(color=borders, width=1.6),
cornerradius=8,
),
text=text_in,
textposition="inside",
insidetextanchor="middle",
textfont=dict(
color=text_colors, size=13, family="Inter, Arial, sans-serif"
),
hovertemplate=(
"<b>%{customdata}</b><br>"
f"{hover_metric_label}: {hover_value_fmt}<extra></extra>"
),
)
)
max_v = max([float(v) for v in values if v is not None] or [1.0])
fig.update_layout(
title=dict(
text=(
f"<b>{title}</b>"
f"<br><span style='font-size:12px;color:#6b7280'>{subtitle}</span>"
),
x=0.0,
xanchor="left",
font=dict(size=18),
),
xaxis=dict(
tickmode="array",
tickvals=x_positions,
ticktext=list(labels),
tickangle=-40,
automargin=True,
showgrid=False,
showline=False,
),
yaxis=dict(
visible=False,
range=[0, max_v * 1.18],
),
template=_TEMPLATE,
height=_FIG_HEIGHT,
margin=dict(l=20, r=20, t=80, b=120),
showlegend=False,
bargap=0.25,
plot_bgcolor="white",
)
return fig
def plot_avg_wer_bars(
    df: pd.DataFrame,
    top_n: int = 10,
) -> go.Figure:
    """
    Ranked vertical bars of each model's average WER (``avg_wer``) in percent.

    The ``top_n`` lowest-WER models (at least one) are shown best first, one
    color per HF org, with the value printed inside each rounded bar.
    """
    if df.empty or "avg_wer" not in df.columns:
        return _empty_fig("No leaderboard WER data yet.")
    ranked = df[["model_id", "avg_wer"]].dropna(subset=["avg_wer"]).copy()
    if ranked.empty:
        return _empty_fig("No WER values computed.")

    ranked["avg_wer"] = pd.to_numeric(ranked["avg_wer"], errors="coerce")
    ranked = ranked.dropna(subset=["avg_wer"])
    ranked["wer_pct"] = ranked["avg_wer"].apply(lambda raw: _wer_pct(float(raw)))
    ranked = ranked.dropna(subset=["wer_pct"])
    limit = max(1, int(top_n))
    ranked = ranked.sort_values("wer_pct", ascending=True).head(limit)

    ids = ranked["model_id"].tolist()
    heights = [float(pct) for pct in ranked["wer_pct"].tolist()]
    return _plot_company_bars(
        labels=[_short_model_label(mid) for mid in ids],
        values=heights,
        model_ids=ids,
        title="WER",
        subtitle="Average WER (%) · Lower is better",
        value_fmt="{:.1f}",
        hover_value_fmt="%{y:.2f}%",
        hover_metric_label="Average WER",
    )
def plot_speed_bars(
    df: pd.DataFrame,
    top_n: int = 10,
) -> go.Figure:
    """
    Ranked vertical bars of throughput (``eval_rtf``, × realtime), fastest first.

    Companion to `plot_avg_wer_bars`: same per-company colors, in-bar value
    labels and rounded corners. Missing or non-positive RTFx values are skipped.
    """
    if df.empty or "eval_rtf" not in df.columns:
        return _empty_fig("No RTF measurements yet.")
    speeds = df[["model_id", "eval_rtf"]].copy()
    speeds["eval_rtf"] = pd.to_numeric(speeds["eval_rtf"], errors="coerce")
    speeds = speeds.dropna(subset=["eval_rtf"])
    speeds = speeds[speeds["eval_rtf"] > 0]
    if speeds.empty:
        return _empty_fig("No RTF measurements yet.")

    fastest = speeds.sort_values("eval_rtf", ascending=False).head(max(1, int(top_n)))
    ids = fastest["model_id"].tolist()
    heights = [float(x) for x in fastest["eval_rtf"].tolist()]
    # Whole numbers for large speeds, one decimal for small ones ("314" vs "0.7").
    label_fmt = "{:.1f}" if max(heights) < 10 else "{:.0f}"
    return _plot_company_bars(
        labels=[_short_model_label(mid) for mid in ids],
        values=heights,
        model_ids=ids,
        title="Speed",
        subtitle="RTFx (audio sec / inference sec) · Higher is better",
        value_fmt=label_fmt,
        hover_value_fmt="%{y:.2f}×",
        hover_metric_label="RTFx",
    )
def plot_leaderboard_score_bars(
    df: pd.DataFrame,
    top_n: int = 40,
    title: str | None = None,
) -> go.Figure:
    """
    Ranked horizontal bar chart of average WER (percent), best model on top.

    Bars sit at integer y positions and are labelled through ``ticktext``. On a
    categorical axis, two models whose truncated names coincide (e.g.
    ``a/whisper-small`` and ``b/whisper-small``) would be merged into a single
    row. Hover shows the full model id.

    Args:
        df: Analytics frame with ``model_id`` and ``avg_wer`` (fractional WER).
        top_n: Maximum number of bars (at least 1).
        title: Optional chart title; defaults to "Average WER (top N models)".
    """
    if df.empty or "avg_wer" not in df.columns:
        return _empty_fig("No leaderboard WER data yet.")
    d = df[["model_id", "avg_wer"]].dropna(subset=["avg_wer"]).copy()
    if d.empty:
        return _empty_fig("No WER values computed.")
    d = _sort_df_leaderboard_order(d).head(max(1, int(top_n)))
    full_ids = d["model_id"].tolist()
    labels = d["model_id"].str.split("/").str[-1].str[:42].tolist()
    # Non-finite WER has no percentage; draw it as a zero-length bar.
    scores = [_wer_pct(v) or 0.0 for v in d["avg_wer"].astype(float)]
    y_positions = list(range(len(labels)))
    fig = go.Figure(
        go.Bar(
            x=scores,
            y=y_positions,
            orientation="h",
            customdata=full_ids,
            marker=dict(
                color=scores,
                colorscale="RdYlGn_r",
                showscale=True,
                colorbar=dict(title="WER (%)"),
            ),
            text=[f"{s:.1f}" for s in scores],
            textposition="outside",
            hovertemplate="<b>%{customdata}</b><br>Average WER: %{x:.2f}%<extra></extra>",
        )
    )
    fig.update_layout(
        title=dict(
            text=title or f"Average WER (top {len(d)} models)",
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(title="WER (%)", rangemode="tozero"),
        yaxis=dict(
            autorange="reversed",  # rank 1 at the top
            tickmode="array",
            tickvals=y_positions,
            ticktext=labels,
        ),
        template=_TEMPLATE,
        height=int(min(920, max(_FIG_HEIGHT, 100 + 22 * len(d)))),
        margin=dict(l=200, r=100, t=60, b=50),
    )
    return fig
def _radar_absolute_value(row: pd.Series, axis: str) -> tuple[float, str]:
"""Map a leaderboard row to an absolute 0–1 radar coordinate (higher = better)."""
if axis == "WER":
wer = pd.to_numeric(row.get("avg_wer"), errors="coerce")
if not np.isfinite(wer) or float(wer) < 0:
return 0.0, "Average WER: N/A"
w = float(wer)
pct = w * 100.0
v = float(max(0.0, min(1.0, 1.0 / (1.0 + w))))
return v, f"Average WER: {pct:.2f}% → strength {v:.3f}"
if axis == "Speed":
rtf = pd.to_numeric(row.get("eval_rtf"), errors="coerce")
if not np.isfinite(rtf) or float(rtf) < 0:
return 0.0, "RTFx: N/A"
r = float(rtf)
v = float(max(0.0, min(1.0, r / (1.0 + r))))
return v, f"RTFx: {r:.4f}× → speed {v:.3f} (RTFx/(1+RTFx))"
if axis == "Compactness":
pm = pd.to_numeric(row.get("num_params_m"), errors="coerce")
if not np.isfinite(pm) or float(pm) < 0:
return 0.0, "Parameters: N/A"
p = float(pm)
v = float(max(0.0, min(1.0, 1.0 / (1.0 + p / 1000.0))))
return v, f"Parameters: {p:.2f} M → compactness {v:.3f}"
return 0.0, f"{axis}: N/A"
# Axis order of the absolute-coordinate radar; each name is understood by
# `_radar_absolute_value`.
_RADAR_ABSOLUTE_AXES: tuple[str, ...] = ("WER", "Speed", "Compactness")


def plot_robustness_radar(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    title: str = "Robustness radar (absolute 0–1; outward is better)",
) -> go.Figure:
    """
    Three-axis radar with **absolute** coordinates in [0, 1], not relative to other models.

    WER strength = 1 / (1 + avg_wer); Speed = RTFx / (1 + RTFx);
    Compactness = 1 / (1 + num_params_m / 1000). Missing values map to 0.

    Plots the requested *model_ids* that exist in *df*. If none match (or none
    are given), the first five rows of *df* are used.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")
    table = df.copy().reset_index(drop=True)
    axes = list(_RADAR_ABSOLUTE_AXES)
    theta = axes + axes[:1]  # repeat the first axis to close the polygon

    known_ids = set(table["model_id"])
    chosen = [mid for mid in (model_ids or []) if mid in known_ids]
    if not chosen:
        chosen = table["model_id"].tolist()[: min(5, len(table))]

    colors = qualitative.Plotly
    fig = go.Figure()
    for i, mid in enumerate(chosen):
        matches = table.index[table["model_id"] == mid].tolist()
        if not matches:
            continue
        row = table.iloc[matches[0]]
        scored = [_radar_absolute_value(row, ax) for ax in axes]
        radii = [value for value, _ in scored]
        notes = [note for _, note in scored]
        name = mid.split("/")[-1][:28]
        color = colors[i % len(colors)]
        fig.add_trace(
            go.Scatterpolar(
                r=radii + radii[:1],
                theta=theta,
                customdata=notes + notes[:1],
                fill="none",
                name=name,
                line=dict(width=2.2, color=color),
                marker=dict(size=6, color=color),
                opacity=1.0,
                hovertemplate=(
                    f"<b>{name}</b><br>"
                    "%{theta}: %{r:.2f}<br>"
                    "%{customdata}<extra></extra>"
                ),
            )
        )

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                showticklabels=True,
                tickformat=".2f",
                tickvals=[0, 0.25, 0.5, 0.75, 1.0],
            ),
            angularaxis=dict(direction="clockwise", rotation=90),
        ),
        title=dict(text=title, x=0.5, xanchor="center"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=50, r=120, t=60, b=50),
        showlegend=True,
        **_plotly_legend_config(),
    )
    return fig
def plot_radar_compare(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    metric_keys: Sequence[str] | None = None,
    title: str = "Robustness radar (absolute 0–1; outward is better)",
) -> go.Figure:
    """
    Deprecated alias for older UI bindings; forwards to `plot_robustness_radar`.

    The default title describes the absolute 0–1 radar that is actually drawn.
    The previous default ("relative strength") no longer matched the chart.
    *metric_keys* is accepted for signature compatibility and ignored: the
    radar's axes are fixed.
    """
    return plot_robustness_radar(df, model_ids, title=title)
def plot_compare_models_across_scenarios(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    metric_keys: Sequence[str],
    title: str = "WER by model (one trace per scenario; use legend to toggle)",
) -> go.Figure:
    """
    Line chart: x = model (sorted by Average WER), y = WER (%), one trace per scenario.

    More models extend the x axis instead of adding traces. The legend therefore
    stays at one entry per scenario, and a vertical slice reads one model across
    every scenario.

    Models sit at integer x positions and are labelled through ``ticktext``. On
    a categorical axis, two models whose truncated names coincide (e.g.
    ``a/whisper-small`` vs ``b/whisper-small``) would collapse onto one tick and
    the lines would double back. Hover shows the full model id.

    Args:
        df: Analytics frame with ``model_id``, the scenario WER columns and,
            optionally, ``avg_wer`` used for ordering.
        model_ids: Models to show on the x axis; empty / None → all rows.
        metric_keys: Scenario columns to plot, one trace each. Keys absent from
            *df* are ignored; ``None`` is treated as empty.
        title: Chart title.
    """
    requested = metric_keys if metric_keys is not None else []
    metric_keys = [k for k in requested if k in df.columns]
    if df.empty or not metric_keys:
        return _empty_fig("Not enough data.")
    selected = [m for m in (model_ids or []) if m]
    sub = df[df["model_id"].isin(selected)].copy() if selected else df.copy()
    if sub.empty:
        return _empty_fig("No matching models.")
    # Sort models by Average WER (lower is better).
    if "avg_wer" in sub.columns:
        sub = sub.sort_values("avg_wer", ascending=True, na_position="last")
    sub = sub.reset_index(drop=True)

    full_ids = sub["model_id"].tolist()
    x_labels = sub["model_id"].str.split("/").str[-1].str[:32].tolist()
    x_positions = list(range(len(full_ids)))
    palette = qualitative.Plotly
    fig = go.Figure()
    for i, k in enumerate(metric_keys):
        m = metric_by_key(k)
        scen_label = m.short if m else k
        ys = [_wer_pct(v) if pd.notna(v) else None for v in sub[k].tolist()]
        fig.add_trace(
            go.Scatter(
                x=x_positions,
                y=ys,
                mode="lines+markers",
                name=scen_label,
                customdata=full_ids,
                line=dict(width=2, color=palette[i % len(palette)]),
                marker=dict(size=8),
                hovertemplate=(
                    f"<b>%{{customdata}}</b><br>{scen_label}: %{{y:.2f}}%<extra></extra>"
                ),
            )
        )
    tickangle = -45 if len(x_labels) > 8 else -25
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(
            title="Model (sorted by Average WER)",
            tickangle=tickangle,
            automargin=True,
            tickmode="array",
            tickvals=x_positions,
            ticktext=x_labels,
        ),
        yaxis=dict(title="WER (%) — lower is better", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=120, t=60, b=140),
        **_plotly_legend_config(),
    )
    return fig
def plot_scenario_heatmap(
df: pd.DataFrame,
metric_keys: Sequence[str],
top_n: int = 30,
title: str | None = None,
) -> go.Figure:
"""
Models × scenarios heatmap of WER values (lower = greener).
This replaces the per-scenario line chart for the common case of *N≫few* models
and a small fixed set of conditions: every row is one model and every column is
one condition, so the eye can immediately spot:
- which models do badly on which conditions (hot rows / cells),
- which conditions are universally hard (whole columns red),
- and how the top of the leaderboard compares to the rest.
`metric_keys` selects which scenario columns to show (in given order). Models
are sorted by Average WER and capped at `top_n` rows so the chart
stays legible even with hundreds of leaderboard entries.
"""
metric_keys = [k for k in metric_keys if k in df.columns]
if df.empty or not metric_keys:
return _empty_fig("Not enough data for the heatmap.")
d = _sort_df_leaderboard_order(df.copy())
d = d.head(max(1, int(top_n))).reset_index(drop=True)
# Omit scenarios with no data in the visible rows (avoids sparse heatmaps).
keep_keys: list[str] = []
for k in metric_keys:
col = pd.to_numeric(d[k], errors="coerce")
if col.notna().any():
keep_keys.append(k)
if not keep_keys:
return _empty_fig("No WER values for the selected scenarios.")
order_index = {k: i for i, k in enumerate(HEATMAP_SCENARIO_KEYS)}
keep_keys = sorted(keep_keys, key=lambda k: order_index.get(k, len(order_index)))
x_labels = [heatmap_label_for_key(k) for k in keep_keys]
y_labels = d["model_id"].str.split("/").str[-1].str[:42].tolist()
z: list[list[float]] = []
cell_text: list[list[str]] = []
for _, row in d.iterrows():
z_row: list[float] = []
t_row: list[str] = []
for k in keep_keys:
v = pd.to_numeric(row.get(k), errors="coerce")
if pd.isna(v):
z_row.append(float("nan"))
t_row.append("")
else:
pct = _wer_pct(float(v))
if pct is None:
z_row.append(float("nan"))
t_row.append("")
else:
z_row.append(float(pct))
t_row.append(f"{pct:.1f}")
z.append(z_row)
cell_text.append(t_row)
flat_vals = [v for row in z for v in row if np.isfinite(v)]
if flat_vals:
raw_max = float(max(flat_vals))
if raw_max <= 25:
zmax = 25.0
elif raw_max <= 60:
zmax = float(np.ceil(raw_max / 5.0) * 5.0)
elif raw_max <= 100:
zmax = float(np.ceil(raw_max / 10.0) * 10.0)
else:
zmax = float(np.ceil(raw_max / 25.0) * 25.0)
else:
zmax = 100.0
fig = go.Figure(
data=go.Heatmap(
z=z,
x=x_labels,
y=y_labels,
text=cell_text,
texttemplate="%{text}",
textfont=dict(size=10, color="#1a1a1a"),
hovertemplate="<b>%{y}</b><br>%{x}: %{z:.2f}%<extra></extra>",
colorscale="RdYlGn_r",
zmin=0.0,
zmax=zmax,
xgap=2,
ygap=2,
colorbar=dict(title="WER (%)", tickformat=".0f"),
)
)
n_rows = max(1, len(y_labels))
# Dynamic height so 30 models gets ~460 px and 80 models gets a taller chart,
# without becoming uselessly tall.
height = int(min(900, max(_FIG_HEIGHT, 80 + 18 * n_rows)))
fig.update_layout(
title=dict(
text=title
or f"WER heatmap: top {n_rows} models × {len(keep_keys)} scenarios (by Average WER)",
x=0.5,
xanchor="center",
),
xaxis=dict(
title=None,
tickangle=-35,
side="top",
automargin=True,
tickfont=dict(size=10),
),
yaxis=dict(title="Model (sorted by Average WER)", automargin=True, autorange="reversed"),
template=_TEMPLATE,
height=height,
margin=dict(l=80, r=60, t=110, b=60),
)
# Scenario tick labels sit on top (side="top"); placing the axis title there too
# collides with the plot title, so render the "Scenario" label at the bottom.
fig.add_annotation(
text="Scenario",
xref="paper",
yref="paper",
x=0.5,
y=-0.02,
yanchor="top",
showarrow=False,
font=dict(size=13),
)
return fig
def plot_clean_vs_reverb_scatter(
    df: pd.DataFrame,
    title: str = "Near Field Speech versus Lab Simulated WER",
) -> go.Figure:
    """
    Scatter of near-field speech WER (x) against lab-simulated WER (y), in percent.

    Points above the dashed ``y = x`` line are models that do worse on the
    lab-simulated set than on near-field speech. The model trace is explicitly
    named, so the (visible) legend reads "Models" rather than Plotly's default
    "trace 1". Hover shows the full model id.
    """
    c1, c2 = "wer_anechoic_speech", "wer_lab_simulated"
    if df.empty or c1 not in df.columns or c2 not in df.columns:
        return _empty_fig("Need near-field speech and lab-simulated WER columns in the leaderboard.")
    d = df[["model_id", c1, c2]].copy()
    d[c1] = pd.to_numeric(d[c1], errors="coerce")
    d[c2] = pd.to_numeric(d[c2], errors="coerce")
    d = d.dropna(subset=[c1, c2])
    if d.empty:
        return _empty_fig("No complete near-field speech / lab-simulated WER pairs yet.")
    x_pct = [_wer_pct(v) for v in d[c1]]
    y_pct = [_wer_pct(v) for v in d[c2]]
    # The y = x reference runs from 0 to the largest plotted value on either axis.
    mx = float(max((v for v in x_pct + y_pct if v is not None), default=0.0))
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=[0, mx],
            y=[0, mx],
            mode="lines",
            name="y = x (equal WER)",
            line=dict(dash="dash", color="rgba(0,0,0,0.35)", width=2),
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_pct,
            y=y_pct,
            mode="markers",
            name="Models",
            customdata=d["model_id"].tolist(),
            marker=dict(size=10, opacity=0.85),
            hovertemplate=(
                "<b>%{customdata}</b><br>Near field speech WER: %{x:.2f}%<br>"
                "Lab simulated WER: %{y:.2f}%<extra></extra>"
            ),
        )
    )
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(title="WER (%), Near Field Speech (lower is better)", rangemode="tozero"),
        yaxis=dict(title="WER (%), Lab Simulated (lower is better)", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=40, t=60, b=60),
        showlegend=True,
    )
    return fig
def plot_pareto_frontier(
    df: pd.DataFrame,
    title: str | None = None,
) -> go.Figure:
    """
    Pareto front: X = Average WER (%, live scenarios), Y = RTFx (log scale).

    Frontier models are highlighted with star markers + name labels and connected
    by a dashed line; non-frontier models render as light dots.

    Args:
        df: Leaderboard frame with ``model_id``, ``avg_wer`` (fractional WER,
            displayed as percent) and ``eval_rtf`` columns. Values are coerced
            with ``pd.to_numeric``; rows whose WER/RTFx is missing or non-finite,
            or whose RTFx is not positive (log axis), are dropped.
        title: Chart title; defaults to "Pareto Front: Average WER vs RTFx".

    Returns:
        The Plotly figure, or an ``_empty_fig`` placeholder when nothing can be
        plotted.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")
    need = {"model_id", "avg_wer", "eval_rtf"}
    if not need.issubset(df.columns):
        return _empty_fig("Missing Average WER or RTFx columns for the Pareto chart.")
    num_cols = ["avg_wer", "eval_rtf"]
    d = df[["model_id", *num_cols]].copy()
    for col in num_cols:
        d[col] = pd.to_numeric(d[col], errors="coerce")
    # Treat +/-inf like NaN: an infinite RTFx would make the log-axis range
    # below infinite, and the frontier helper ignores non-finite points anyway.
    d[num_cols] = d[num_cols].replace([np.inf, -np.inf], np.nan)
    d = d.dropna(subset=num_cols)
    d = d[d["eval_rtf"] > 0].reset_index(drop=True)
    if d.empty:
        return _empty_fig(
            "No models with Average WER and RTFx. Re-run evaluations to populate timing."
        )
    # Fraction -> percent (same conversion as _wer_pct); every value is finite here.
    d["wer_pct"] = d["avg_wer"].astype(float) * 100.0

    # Layer-1 frontier on (avg_wer, eval_rtf): lower WER + higher RTF is better.
    fx_raw, fy = _pareto_efficient_frontier_wer_rtf(
        d["avg_wer"].tolist(),
        d["eval_rtf"].tolist(),
    )
    # Map frontier points back to rows; rounding guards against float noise
    # introduced by the helper's float() round-trip.
    frontier_pairs = {(round(w, 8), round(r, 8)) for w, r in zip(fx_raw, fy)}
    d["_frontier"] = [
        (round(float(w), 8), round(float(r), 8)) in frontier_pairs
        for w, r in zip(d["avg_wer"], d["eval_rtf"])
    ]
    frontier_df = d[d["_frontier"]].sort_values("wer_pct").reset_index(drop=True)
    other_df = d[~d["_frontier"]]

    accent = "#1f77ff"  # cool blue (matches screenshot)
    accent_light = "rgba(31, 119, 255, 0.30)"
    hover = (
        "<b>%{text}</b><br>"
        "Average WER: %{x:.2f}%<br>"
        "RTFx: %{y:.2f}×<extra></extra>"
    )
    fig = go.Figure()
    if not other_df.empty:
        fig.add_trace(
            go.Scatter(
                x=other_df["wer_pct"],
                y=other_df["eval_rtf"],
                mode="markers",
                name="Other models",
                marker=dict(size=9, color=accent_light, line=dict(width=0, color="white")),
                text=other_df["model_id"],
                hovertemplate=hover,
            )
        )
    if not frontier_df.empty:
        fig.add_trace(
            go.Scatter(
                x=frontier_df["wer_pct"],
                y=frontier_df["eval_rtf"],
                mode="lines+markers+text",
                name="Pareto frontier",
                line=dict(color=accent, width=2, dash="dash"),
                marker=dict(size=14, color=accent, symbol="star", line=dict(width=0, color="white")),
                text=frontier_df["model_id"],
                textposition="top center",
                textfont=dict(size=11, color=accent),
                cliponaxis=False,
                hovertemplate=hover,
            )
        )

    ttl = title or "Pareto Front: Average WER vs RTFx"
    rtf_min = float(d["eval_rtf"].min())
    rtf_max = float(d["eval_rtf"].max())
    # Plotly expects log-axis ranges in log10 units.
    log_y_lo = np.log10(max(rtf_min * 0.85, 1e-4))
    log_y_hi = np.log10(rtf_max * 1.2)
    # WER can exceed 100% (insertions count as errors); a fixed [0, 100] range
    # silently hid such models. Keep 0-100 as the minimum span, widen if needed.
    x_hi = max(100.0, float(d["wer_pct"].max()) * 1.05)
    fig.update_layout(
        title=dict(text=ttl, x=0.0, xanchor="left", font=dict(size=18)),
        xaxis=dict(
            title="Average WER (%) (lower is better)",
            zeroline=False,
            range=[0, x_hi],
        ),
        yaxis=dict(
            title="RTFx (higher is better)",
            type="log",
            zeroline=False,
            range=[log_y_lo, log_y_hi],
            tickformat=".2g",
        ),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        autosize=True,
        margin=dict(l=70, r=40, t=70, b=60),
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.98,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255,255,255,0.7)",
            borderwidth=0,
        ),
    )
    return fig
def plot_scenario_bar_summary(df: pd.DataFrame, top_n: int = 8) -> go.Figure:
    """
    Grouped bars: one trace per model; legend toggles models.

    Args:
        df: Leaderboard frame with a ``model_id`` column and fractional WER
            columns whose names come from ``HEATMAP_SCENARIO_KEYS`` (keys not
            present in ``df`` are skipped).
        top_n: How many models to show. Models are ranked by ascending
            ``avg_wer`` when that column exists, otherwise by the mean of the
            available scenario columns; missing ranks sort last.

    Returns:
        The Plotly figure, or an ``_empty_fig`` placeholder when there is no data.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")
    scenarios = [c for c in HEATMAP_SCENARIO_KEYS if c in df.columns]
    if not scenarios:
        return _empty_fig("No core WER columns.")

    # Rank on a numeric copy: if values arrive as strings, an object column
    # would sort lexicographically ("10.5" < "9.2") and pick the wrong models.
    if "avg_wer" in df.columns:
        rank_key = pd.to_numeric(df["avg_wer"], errors="coerce")
    else:
        rank_key = (
            df[scenarios]
            .apply(pd.to_numeric, errors="coerce")
            .mean(axis=1, skipna=True)
        )
    d = (
        df.assign(_rank=rank_key)
        .sort_values("_rank", ascending=True, na_position="last")
        .head(int(top_n))
    )

    # Prefer the metric's short display name; fall back to the raw column key.
    x_labels = []
    for key in scenarios:
        spec = metric_by_key(key)
        x_labels.append(spec.short if spec else key)

    fig = go.Figure()
    palette = qualitative.Plotly * 3
    for i, (_, row) in enumerate(d.iterrows()):
        heights = [
            _wer_pct(row[c]) if pd.notna(row[c]) else None for c in scenarios
        ]
        short = str(row["model_id"]).split("/")[-1][:22]
        fig.add_trace(
            go.Bar(
                name=short,
                x=x_labels,
                y=heights,
                marker_color=palette[i % len(palette)],
                hovertemplate=f"<b>{short}</b><br>%{{x}}: %{{y:.2f}}%<extra></extra>",
            )
        )
    fig.update_layout(
        barmode="group",
        title=dict(
            text=f"WER by scenario: top {len(d)} models by Average WER (legend toggles models)",
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(title="Scenario", tickangle=-20),
        yaxis=dict(title="WER (%)", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=120, t=60, b=100),
        **_plotly_legend_config(),
    )
    return fig
def plot_lines_by_rank(df: pd.DataFrame, metric_keys: Sequence[str], title: str | None = None) -> go.Figure:
    """Deprecated compatibility shim: use plot_clean_vs_reverb_scatter or plot_latency_vs_wer.

    ``metric_keys`` is accepted only so old call sites keep working; it is
    ignored. The frame is forwarded to ``plot_clean_vs_reverb_scatter`` with
    ``title`` (or a default title when ``title`` is empty/None).
    """
    effective_title = title if title else "Clean vs reverberant WER"
    return plot_clean_vs_reverb_scatter(df, title=effective_title)