""" Interactive Plotly charts for the Analysis tab (legend toggles traces, zoom, pan, hover). Radar = compare models on normalized axes. Middle chart = WER across scenarios (one trace per model). Bar chart = grouped WER by scenario for top models. """ from __future__ import annotations import threading from typing import Sequence import numpy as np import pandas as pd import plotly.graph_objects as go from plotly.colors import qualitative from metrics_config import ( HEATMAP_SCENARIO_KEYS, LIVE_SCENARIO_KEYS, SCENARIO_METRICS, heatmap_label_for_key, metric_by_key, ) # Mean WER over realistic SNR splits (Pareto X-axis and ranking). SNR_WER_KEYS: tuple[str, ...] = ( "wer_realistic_high_snr", "wer_realistic_mid_snr", "wer_realistic_low_snr", ) def _wer_pct(v) -> float | None: """Fractional WER (0–1) → percent for chart display.""" if v is None or (isinstance(v, float) and not np.isfinite(v)): return None try: f = float(v) except (TypeError, ValueError): return None if not np.isfinite(f): return None return f * 100.0 def _pareto_efficient_frontier_wer_rtf( wers: Sequence[float], rtfs: Sequence[float] ) -> tuple[list[float], list[float]]: """Layer-1 envelope: X = WER (lower better), Y = RTF (higher better).""" pts = [ (float(w), float(r)) for w, r in zip(wers, rtfs) if np.isfinite(w) and np.isfinite(r) ] if not pts: return [], [] frontier: list[tuple[float, float]] = [] for i, (wi, ri) in enumerate(pts): dominated = False for j, (wj, rj) in enumerate(pts): if i == j: continue if wj <= wi and rj >= ri and (wj < wi or rj > ri): dominated = True break if not dominated: frontier.append((wi, ri)) frontier.sort(key=lambda p: p[0]) return [p[0] for p in frontier], [p[1] for p in frontier] # Consistent height for Gradio Plot _FIG_HEIGHT = 460 _TEMPLATE = "plotly_white" # Weight-file size lookup for the "Memory" radar axis. Populated lazily on first render; # falls back to NaN if the hub call fails (e.g. gated repo without token). 
_model_size_cache: dict[str, float] = {}
_model_size_lock = threading.Lock()
_WEIGHT_EXTS = (".safetensors", ".bin", ".pt", ".pth", ".gguf", ".ckpt", ".msgpack", ".h5")


def _cached_model_size_mb(model_id: str) -> float:
    """Total size (MiB) of a Hub repo's weight files, memoised per ``model_id``.

    Sums the ``size`` of every sibling file whose name ends with one of
    ``_WEIGHT_EXTS``. Returns NaN when the id is blank / not a string, when
    the Hub lookup fails for any reason, or when no sized weight file is
    found. Results (including NaN) are cached so each id is looked up once.
    """
    # Rows without a model id reach here as NaN / "" via DataFrame.apply;
    # there is nothing to look up, so skip the network round-trip entirely.
    if not isinstance(model_id, str) or not model_id.strip():
        return float("nan")

    with _model_size_lock:
        if model_id in _model_size_cache:
            return _model_size_cache[model_id]

    # The lock is not held during the lookup, so concurrent callers may fetch
    # the same id twice; the second write simply overwrites the first.
    total = 0
    try:
        from huggingface_hub import HfApi

        info = HfApi().model_info(model_id, files_metadata=True)
        for sibling in info.siblings or []:
            name = getattr(sibling, "rfilename", "") or ""
            size = getattr(sibling, "size", None) or 0
            if name.lower().endswith(_WEIGHT_EXTS):
                total += int(size)
    except Exception:
        # Deliberate best-effort: any failure (missing package, network,
        # gated repo) degrades to NaN instead of breaking the chart.
        total = 0

    mb = float(total) / (1024.0 * 1024.0) if total > 0 else float("nan")
    with _model_size_lock:
        _model_size_cache[model_id] = mb
    return mb


def _coerce_nonneg_float(v) -> float | None:
    """Return *v* as a finite float >= 0, or ``None`` if missing / invalid.

    ``None``, ``""``, unparsable strings and NA-like objects all fail the
    ``float()`` conversion and map to ``None``; no separate guard is needed.
    """
    try:
        x = float(v)
    except (TypeError, ValueError):
        return None
    return x if np.isfinite(x) and x >= 0 else None


def _coerce_wer(v) -> float | None:
    """Non-negative finite WER, or None if missing / invalid."""
    return _coerce_nonneg_float(v)


def _coerce_rtf(v) -> float | None:
    """Non-negative finite RTF (audio sec / inference sec), or None if missing."""
    return _coerce_nonneg_float(v)


def snr_avg_wer(row: dict) -> float | None:
    """Mean WER over high / mid / low SNR scenarios (skip missing).

    Returns ``None`` when none of ``SNR_WER_KEYS`` holds a valid WER.
    """
    vals = [
        w for w in (_coerce_wer(row.get(k)) for k in SNR_WER_KEYS) if w is not None
    ]
    if not vals:
        return None
    return float(sum(vals) / len(vals))


def compute_pareto_layers(rows: list[dict]) -> dict[str, int | None]:
    """
    Pareto non-dominated layers over (snr_avg_wer ↓, eval_rtf ↑).

    Layer 1 = efficient frontier; peel and repeat. Missing WER or RTF → ``None``.

    Rows with a blank ``model_id`` are ignored. Returns a mapping
    ``model_id -> layer`` (1-based) with ``None`` for unrankable models.
    """
    pts: list[tuple[str, float, float]] = []
    out: dict[str, int | None] = {}
    for r in rows:
        mid = (r.get("model_id") or "").strip()
        if not mid:
            continue
        w = snr_avg_wer(r)
        rtf = _coerce_rtf(r.get("eval_rtf"))
        if w is None or rtf is None:
            out[mid] = None
        else:
            pts.append((mid, w, rtf))

    unassigned = list(pts)
    layer = 1
    while unassigned:
        # A point is on the current frontier when no remaining point is at
        # least as good on both axes and strictly better on one. A point can
        # never dominate itself, so it needs no special-casing.
        frontier = [
            mid_i
            for mid_i, w_i, r_i in unassigned
            if not any(
                w_j <= w_i and r_j >= r_i and (w_j < w_i or r_j > r_i)
                for _, w_j, r_j in unassigned
            )
        ]
        frontier_set = set(frontier)
        for mid in frontier:
            out[mid] = layer
        unassigned = [p for p in unassigned if p[0] not in frontier_set]
        layer += 1
    return out


def _avg_wer_for_row(row: dict) -> float:
    """Mean WER over live scenarios; ``inf`` when nothing populated."""
    vals = [
        w for w in (_coerce_wer(row.get(k)) for k in LIVE_SCENARIO_KEYS) if w is not None
    ]
    if not vals:
        return float("inf")
    return sum(vals) / len(vals)


def sort_leaderboard_rows_inplace(rows: list[dict]) -> None:
    """Sort rows by Average WER ascending (lower WER first); ties by ``model_id``."""
    rows.sort(key=lambda r: (_avg_wer_for_row(r), (r.get("model_id") or "")))


def frontier_label(layer: int | None) -> str:
    """Display text for a Pareto layer: the number, or ``"NA"`` when unranked."""
    return str(layer) if layer is not None else "NA"


def _log_normalize(values: np.ndarray, *, inverse: bool, strength_floor: float = 0.1) -> np.ndarray:
    """
    Log10-normalize an array across all models to a "strength" score (default display floor 0.1).

    * `inverse=True`  -> lower raw value is better (WER, memory). Strength = 1 - norm.
    * `inverse=False` -> higher raw value is better (RTF).        Strength = norm.

    Valid entries are min-max scaled in log space and clamped to
    ``[strength_floor, 1.0]``; when all valid entries are (nearly) equal they
    all get 0.5. Returns NaN for entries whose raw value is missing or
    non-positive.
    """
    v = np.asarray(values, dtype=float)
    out = np.full_like(v, np.nan, dtype=float)
    mask = np.isfinite(v) & (v > 0)
    if not mask.any():
        return out

    logv = np.log10(np.maximum(v[mask], 1e-12))
    lo = float(np.min(logv))
    hi = float(np.max(logv))
    if hi - lo < 1e-9:
        # Degenerate range: no spread to normalise, place everything mid-scale.
        scaled = np.full_like(logv, 0.5)
    else:
        scaled = (logv - lo) / (hi - lo)
    if inverse:
        scaled = 1.0 - scaled
    # A single clip applies both the display floor and the upper bound.
    out[mask] = np.clip(scaled, strength_floor, 1.0)
    return out


# Axis spec: (strength_label, primary col, fallback col, inverse, raw_display_name, raw_unit)
#   strength_label: chart axis label; always phrased so "farther = better".
#   inverse       : True if *lower* raw value is better (e.g. WER, memory); False if higher is better (RTF).
#   raw_display_name: human-readable name of the underlying metric, shown in the hover tooltip.
#   raw_unit      : optional unit suffix for the raw value in the tooltip.
#
# Live far-field scenario axes + Speed (RTF) + Compactness (parameters / Hub size).
_RADAR_AXES: tuple[tuple[str, str, str | None, bool, str, str], ...] = (
    ("Near Field Speech", "wer_anechoic_speech", None, True, "Near field speech WER", ""),
    ("Lab Measured", "wer_lab_measured", None, True, "Lab measured WER", ""),
    ("Lab Simulated", "wer_lab_simulated", None, True, "Lab simulated WER", ""),
    ("High SNR", "wer_realistic_high_snr", None, True, "High SNR WER", ""),
    ("Mid SNR", "wer_realistic_mid_snr", None, True, "Mid SNR WER", ""),
    ("Low SNR", "wer_realistic_low_snr", None, True, "Low SNR WER", ""),
    ("Moving Sources", "wer_moving_sources", None, True, "Moving Sources WER", ""),
    ("Speed", "eval_rtf", None, False, "RTFx", "× realtime"),
    # Compactness: smaller model -> more compact -> farther from centre.
    # `inverse=True` makes `_log_normalize` emit `1 - norm`, so the smallest
    # observed model maps to strength = 1.0 (radar edge) and the largest
    # maps to the strength floor (0.1 by default) rather than exactly 0.0,
    # i.e. it sits near the centre.
    #
    # Primary signal: parameter count (millions), which is recorded for every
    # backend by `attach_params` and is much more reliable than the Hub
    # weight-file size lookup (which can return NaN for gated/private repos
    # or repos that ship weights via LFS pointers without size metadata).
    # `model_size_mb` is the fallback for legacy rows that pre-date param
    # tracking.
    ("Compactness", "num_params_m", "model_size_mb", True, "Parameters", " M"),
)

# Hover-label override when a radar axis falls back from `primary` to
# `fallback`. Without this the tooltip would still claim the value is in the
# primary metric (e.g. "Parameters: 1500.0000 M") when in fact the rendered
# value came from the fallback column (`Model size: 1500.0000 MB`).
_RADAR_FALLBACK_DISPLAY: dict[str, tuple[str, str]] = {
    "num_params_m": ("Model size", " MB"),
}


def _empty_fig(message: str) -> go.Figure:
    """Placeholder figure: hidden axes with *message* centred in the plot area."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=14),
    )
    fig.update_layout(
        template=_TEMPLATE,
        height=400,
        margin=dict(l=40, r=40, t=40, b=40),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
    )
    return fig


def _plotly_legend_config() -> dict:
    """Legend click toggles traces; double-click isolates.

    Returns layout kwargs placing a vertical legend just right of the plot
    area, with ``hovermode="closest"``.
    """
    return dict(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02,
            font=dict(size=10),
        ),
        hovermode="closest",
    )


def _raw_to_analytics_df(raw: list[dict]) -> pd.DataFrame:
    """
    Parse leaderboard rows; coerce scenario WER + timing columns to float (NaN if missing).

    Returns an empty DataFrame when *raw* is empty or no row has ``model_id``.

    Computes:
      * ``num_params_m``: ``num_params`` expressed in millions.
      * ``avg_wer``: mean of live scenario WERs present for that row (reference).
      * ``pareto_layer``: layer from ``compute_pareto_layers`` (``None`` if unranked).
      * ``snr_avg_wer``: mean WER over the ``SNR_WER_KEYS`` scenarios.
      * ``model_size_mb``: total weight-file size fetched from the Hub (cached).
    """
    from init import normalize_legacy_csv_row

    if not raw:
        return pd.DataFrame()
    rows = [normalize_legacy_csv_row(dict(r)) for r in raw]
    df = pd.DataFrame(rows)
    if "model_id" not in df.columns:
        return pd.DataFrame()

    # Scenario WER columns first, then timing / size columns: every one ends
    # up as a float column, created as all-NaN when the input lacks it.
    numeric_cols = [m.key for m in SCENARIO_METRICS]
    numeric_cols += ["eval_wall_time_s", "eval_rtf", "eval_audio_seconds", "num_params"]
    for k in numeric_cols:
        if k not in df.columns:
            df[k] = np.nan
        else:
            df[k] = pd.to_numeric(df[k], errors="coerce")

    # Millions of parameters; compact display.
    df["num_params_m"] = df["num_params"] / 1e6

    live_cols = [c for c in LIVE_SCENARIO_KEYS if c in df.columns]
    df["avg_wer"] = df[live_cols].mean(axis=1, skipna=True) if live_cols else np.nan

    layers = compute_pareto_layers(rows)
    df["pareto_layer"] = [
        layers.get((r.get("model_id") or "").strip()) for r in rows
    ]
    df["snr_avg_wer"] = [snr_avg_wer(r) for r in rows]
    df["model_size_mb"] = df["model_id"].apply(_cached_model_size_mb)
    return df


def _sort_df_leaderboard_order(df: pd.DataFrame) -> pd.DataFrame:
    """Average WER ascending (lower is better); rows without a value go last.

    An empty frame is returned as-is; otherwise a sorted copy is returned.
    The sort is stable, so rows with equal ``avg_wer`` keep their input order.
    """
    if df.empty:
        return df
    d = df.copy()
    if "avg_wer" in d.columns:
        return d.sort_values(
            "avg_wer", ascending=True, na_position="last", kind="mergesort"
        )
    return d


def available_metric_keys(df: pd.DataFrame) -> list[str]:
    """Scenario keys that have at least one non-NaN value."""
    return [
        m.key
        for m in SCENARIO_METRICS
        if m.key in df.columns and df[m.key].notna().any()
    ]


# Brand-aligned colors per HF org / company. Lower-cased org prefix → hex.
# Unknown orgs fall back to a deterministic palette so that the same company
# keeps the same color across the Intelligence / Speed charts.
_COMPANY_COLORS: dict[str, str] = { "nvidia": "#76B900", # NVIDIA green "openai": "#111111", # OpenAI black "openai-community": "#111111", "qwen": "#615CED", # Qwen purple "ibm-granite": "#0F62FE", # IBM blue "ibm": "#0F62FE", "coherelabs": "#39594D", # Cohere muted green "cohere": "#39594D", "efficient-speech": "#F2994A", "usefulsensors": "#22A7F0", "google": "#4285F4", # Google blue "google-t5": "#4285F4", "meta-llama": "#1877F2", # Meta blue "facebook": "#1877F2", "microsoft": "#00A4EF", "deepseek-ai": "#1D6FE0", "deepseek": "#1D6FE0", "moonshotai": "#111111", "mistralai": "#FF6F00", "anthropic": "#CC9B73", "x-ai": "#000000", "xai": "#000000", "zhipuai": "#1F8AFF", "zai-org": "#1F8AFF", } _FALLBACK_PALETTE: tuple[str, ...] = ( "#1F77B4", "#FF7F0E", "#2CA02C", "#D62728", "#9467BD", "#8C564B", "#E377C2", "#7F7F7F", "#BCBD22", "#17BECF", ) def _company_key(model_id: str) -> str: """Normalized company key from a HF model id (`org/name` → `org`).""" mid = (model_id or "").strip() if "/" in mid: return mid.split("/", 1)[0].lower() return mid.lower() def _company_color_map(model_ids: Sequence[str]) -> dict[str, str]: """Stable model_id → hex color mapping; same company shares a color.""" out: dict[str, str] = {} fallback_assign: dict[str, str] = {} unknown_companies: list[str] = [] for mid in model_ids: comp = _company_key(mid) if comp in _COMPANY_COLORS: out[mid] = _COMPANY_COLORS[comp] else: if comp not in fallback_assign: unknown_companies.append(comp) out[mid] = "" # filled in below # Deterministic palette assignment for unknown orgs (sorted alphabetically). 
for i, comp in enumerate(sorted(set(unknown_companies))): fallback_assign[comp] = _FALLBACK_PALETTE[i % len(_FALLBACK_PALETTE)] for mid in model_ids: if not out[mid]: out[mid] = fallback_assign[_company_key(mid)] return out def _short_model_label(model_id: str, max_len: int = 26) -> str: name = (model_id or "").split("/", 1)[-1] return name if len(name) <= max_len else name[: max_len - 1] + "…" def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]: h = hex_color.lstrip("#") return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) def _mix(c: int, target: int, t: float) -> int: return int(round(c + (target - c) * max(0.0, min(1.0, t)))) def _lighten(hex_color: str, t: float) -> str: """t in [0,1] -> mix toward white (t=0 returns the original color).""" r, g, b = _hex_to_rgb(hex_color) return "#{:02X}{:02X}{:02X}".format(_mix(r, 255, t), _mix(g, 255, t), _mix(b, 255, t)) def _darken(hex_color: str, t: float) -> str: """t in [0,1] -> mix toward black (t=0 returns the original color).""" r, g, b = _hex_to_rgb(hex_color) return "#{:02X}{:02X}{:02X}".format(_mix(r, 0, t), _mix(g, 0, t), _mix(b, 0, t)) def _to_rgba(hex_color: str, alpha: float) -> str: r, g, b = _hex_to_rgb(hex_color) return f"rgba({r},{g},{b},{alpha:.3f})" def _company_shade_assignments( model_ids: Sequence[str], color_map: dict[str, str], ) -> tuple[list[str], list[str], list[float]]: """ For bars whose company appears multiple times in this chart, vary the luminance + opacity so adjacent same-color bars remain distinguishable. 
Returns parallel lists (one entry per model_id): * fill rgba color (with per-bar alpha baked in) * border hex color (a darker shade of the base brand color) * text opacity (kept ≥ 0.85 so in-bar labels stay legible) """ counts: dict[str, int] = {} for mid in model_ids: counts[_company_key(mid)] = counts.get(_company_key(mid), 0) + 1 seen: dict[str, int] = {} fills: list[str] = [] borders: list[str] = [] text_alphas: list[float] = [] for mid in model_ids: base = color_map[mid] comp = _company_key(mid) n = counts[comp] idx = seen.get(comp, 0) seen[comp] = idx + 1 if n > 1: # First instance keeps base color; subsequent ones lighten ~12% per step # (capped at ~45%) and reduce alpha slightly so the boundary is obvious # even when two same-company bars sit next to each other. lighten_t = min(0.45, 0.12 * idx) alpha = max(0.65, 1.0 - 0.08 * idx) shaded = _lighten(base, lighten_t) fills.append(_to_rgba(shaded, alpha)) borders.append(_darken(base, 0.35)) text_alphas.append(max(0.9, 1.0 - 0.05 * idx)) else: fills.append(_to_rgba(base, 1.0)) borders.append(_darken(base, 0.35)) text_alphas.append(1.0) return fills, borders, text_alphas def _plot_company_bars( labels: Sequence[str], values: Sequence[float], model_ids: Sequence[str], *, title: str, subtitle: str, value_fmt: str, hover_value_fmt: str, hover_metric_label: str, ) -> go.Figure: """ Shared renderer for the Intelligence / Speed vertical bar charts. - Bars are colored by HF org via `_company_color_map`. - Same-company bars in the same chart get progressively lighter fills + slightly reduced alpha so they remain distinguishable. - Each bar has a thin darker outline ("corner line") tracing the rounded corners so it stands out from the white background. - Values are drawn inside the bar (white) with rounded corners. - Hover shows the full model id and the metric value. 
""" color_map = _company_color_map(list(model_ids)) fills, borders, text_alphas = _company_shade_assignments( list(model_ids), color_map ) text_in = [value_fmt.format(v) for v in values] text_colors = [f"rgba(255,255,255,{a:.3f})" for a in text_alphas] # Use integer x-positions so duplicate display labels (e.g. two models that # share the same truncated short name) don't collapse into a single stacked # bar. Tick labels are mapped back onto these positions below. x_positions = list(range(len(labels))) fig = go.Figure( go.Bar( x=x_positions, y=values, customdata=list(model_ids), marker=dict( color=fills, line=dict(color=borders, width=1.6), cornerradius=8, ), text=text_in, textposition="inside", insidetextanchor="middle", textfont=dict( color=text_colors, size=13, family="Inter, Arial, sans-serif" ), hovertemplate=( "%{customdata}
" f"{hover_metric_label}: {hover_value_fmt}" ), ) ) max_v = max([float(v) for v in values if v is not None] or [1.0]) fig.update_layout( title=dict( text=( f"{title}" f"
{subtitle}" ), x=0.0, xanchor="left", font=dict(size=18), ), xaxis=dict( tickmode="array", tickvals=x_positions, ticktext=list(labels), tickangle=-40, automargin=True, showgrid=False, showline=False, ), yaxis=dict( visible=False, range=[0, max_v * 1.18], ), template=_TEMPLATE, height=_FIG_HEIGHT, margin=dict(l=20, r=20, t=80, b=120), showlegend=False, bargap=0.25, plot_bgcolor="white", ) return fig def plot_avg_wer_bars( df: pd.DataFrame, top_n: int = 10, ) -> go.Figure: """ Ranked vertical bars of average WER (``avg_wer``) as a percent. Top models sorted ascending (lower WER is better); value inside the bar, rounded corners, one color per HF org. """ if df.empty or "avg_wer" not in df.columns: return _empty_fig("No leaderboard WER data yet.") d = df[["model_id", "avg_wer"]].dropna(subset=["avg_wer"]).copy() if d.empty: return _empty_fig("No WER values computed.") d["avg_wer"] = pd.to_numeric(d["avg_wer"], errors="coerce") d = d.dropna(subset=["avg_wer"]) d["wer_pct"] = d["avg_wer"].apply(lambda v: _wer_pct(float(v))) d = d.dropna(subset=["wer_pct"]) d = d.sort_values("wer_pct", ascending=True).head(max(1, int(top_n))) labels = [_short_model_label(m) for m in d["model_id"].tolist()] return _plot_company_bars( labels=labels, values=[float(v) for v in d["wer_pct"].tolist()], model_ids=d["model_id"].tolist(), title="WER", subtitle="Average WER (%) · Lower is better", value_fmt="{:.1f}", hover_value_fmt="%{y:.2f}%", hover_metric_label="Average WER", ) def plot_speed_bars( df: pd.DataFrame, top_n: int = 10, ) -> go.Figure: """ Ranked vertical bars of throughput (`eval_rtf`, × realtime). Companion to `plot_avg_wer_bars`: same color-per-company convention, same in-bar value labels, rounded corners. 
""" if df.empty or "eval_rtf" not in df.columns: return _empty_fig("No RTF measurements yet.") d = df[["model_id", "eval_rtf"]].copy() d["eval_rtf"] = pd.to_numeric(d["eval_rtf"], errors="coerce") d = d.dropna(subset=["eval_rtf"]) d = d[d["eval_rtf"] > 0] if d.empty: return _empty_fig("No RTF measurements yet.") d = d.sort_values("eval_rtf", ascending=False).head(max(1, int(top_n))) labels = [_short_model_label(m) for m in d["model_id"].tolist()] values = [float(v) for v in d["eval_rtf"].tolist()] # Show ints when large, one decimal when small (matches AA "314" / "0.7" feel). fmt = "{:.0f}" if max(values) >= 10 else "{:.1f}" return _plot_company_bars( labels=labels, values=values, model_ids=d["model_id"].tolist(), title="Speed", subtitle="RTFx (audio sec / inference sec) · Higher is better", value_fmt=fmt, hover_value_fmt="%{y:.2f}×", hover_metric_label="RTFx", ) def plot_leaderboard_score_bars( df: pd.DataFrame, top_n: int = 40, title: str | None = None, ) -> go.Figure: """ Ranked horizontal bar chart of average WER (percent). """ if df.empty or "avg_wer" not in df.columns: return _empty_fig("No leaderboard WER data yet.") cols = ["model_id", "avg_wer"] if "pareto_layer" in df.columns: cols.append("pareto_layer") d = df[cols].dropna(subset=["avg_wer"]).copy() if d.empty: return _empty_fig("No WER values computed.") n = max(1, int(top_n)) d = _sort_df_leaderboard_order(d).head(n) labels = d["model_id"].str.split("/").str[-1].str[:42] scores = d["avg_wer"].astype(float).apply(lambda v: _wer_pct(v) or 0.0) fig = go.Figure( go.Bar( x=scores, y=labels, orientation="h", marker=dict( color=scores, colorscale="RdYlGn_r", showscale=True, colorbar=dict(title="WER (%)"), ), text=[f"{s:.1f}" for s in scores], textposition="outside", hovertemplate="%{y}
Average WER: %{x:.2f}%", ) ) fig.update_layout( title=dict( text=title or f"Average WER (top {len(d)} models)", x=0.5, xanchor="center", ), xaxis=dict(title="WER (%)", rangemode="tozero"), yaxis=dict(autorange="reversed"), template=_TEMPLATE, height=int(min(920, max(_FIG_HEIGHT, 100 + 22 * len(d)))), margin=dict(l=200, r=100, t=60, b=50), ) return fig def _radar_absolute_value(row: pd.Series, axis: str) -> tuple[float, str]: """Map a leaderboard row to an absolute 0–1 radar coordinate (higher = better).""" if axis == "WER": wer = pd.to_numeric(row.get("avg_wer"), errors="coerce") if not np.isfinite(wer) or float(wer) < 0: return 0.0, "Average WER: N/A" w = float(wer) pct = w * 100.0 v = float(max(0.0, min(1.0, 1.0 / (1.0 + w)))) return v, f"Average WER: {pct:.2f}% → strength {v:.3f}" if axis == "Speed": rtf = pd.to_numeric(row.get("eval_rtf"), errors="coerce") if not np.isfinite(rtf) or float(rtf) < 0: return 0.0, "RTFx: N/A" r = float(rtf) v = float(max(0.0, min(1.0, r / (1.0 + r)))) return v, f"RTFx: {r:.4f}× → speed {v:.3f} (RTFx/(1+RTFx))" if axis == "Compactness": pm = pd.to_numeric(row.get("num_params_m"), errors="coerce") if not np.isfinite(pm) or float(pm) < 0: return 0.0, "Parameters: N/A" p = float(pm) v = float(max(0.0, min(1.0, 1.0 / (1.0 + p / 1000.0)))) return v, f"Parameters: {p:.2f} M → compactness {v:.3f}" return 0.0, f"{axis}: N/A" _RADAR_ABSOLUTE_AXES: tuple[str, ...] = ("WER", "Speed", "Compactness") def plot_robustness_radar( df: pd.DataFrame, model_ids: Sequence[str], title: str = "Robustness radar (absolute 0–1; outward is better)", ) -> go.Figure: """ Three-axis radar with **absolute** coordinates in [0, 1] (not relative to other models). WER strength = 1 / (1 + avg_wer); Speed = RTFx / (1 + RTFx); Compactness = 1 / (1 + num_params_m / 1000). Missing values map to 0. 
""" if df.empty: return _empty_fig("No leaderboard data yet.") d = df.copy().reset_index(drop=True) labels = list(_RADAR_ABSOLUTE_AXES) labels_closed = labels + [labels[0]] selected = [m for m in (model_ids or []) if m in set(d["model_id"])] if not selected: selected = d["model_id"].tolist()[: min(5, len(d))] fig = go.Figure() palette = qualitative.Plotly * 3 for i, mid in enumerate(selected): rows = d.index[d["model_id"] == mid].tolist() if not rows: continue row = d.iloc[rows[0]] r_vals: list[float] = [] raw_text: list[str] = [] for ax in labels: v, tip = _radar_absolute_value(row, ax) r_vals.append(v) raw_text.append(tip) r_closed = r_vals + [r_vals[0]] raw_closed = raw_text + [raw_text[0]] short = mid.split("/")[-1][:28] color = palette[i % len(palette)] fig.add_trace( go.Scatterpolar( r=r_closed, theta=labels_closed, customdata=raw_closed, fill="none", name=short, line=dict(width=2.2, color=color), marker=dict(size=6, color=color), opacity=1.0, hovertemplate=( f"{short}
" "%{theta}: %{r:.2f}
" "%{customdata}" ), ) ) fig.update_layout( polar=dict( radialaxis=dict( visible=True, range=[0, 1], showticklabels=True, tickformat=".2f", tickvals=[0, 0.25, 0.5, 0.75, 1.0], ), angularaxis=dict(direction="clockwise", rotation=90), ), title=dict(text=title, x=0.5, xanchor="center"), template=_TEMPLATE, height=_FIG_HEIGHT, margin=dict(l=50, r=120, t=60, b=50), showlegend=True, **_plotly_legend_config(), ) return fig def plot_radar_compare( df: pd.DataFrame, model_ids: Sequence[str], metric_keys: Sequence[str] | None = None, title: str = "Robustness profile (relative strength; outward is better)", ) -> go.Figure: """Deprecated alias for older UI bindings; forwards to `plot_robustness_radar`.""" return plot_robustness_radar(df, model_ids, title=title) def plot_compare_models_across_scenarios( df: pd.DataFrame, model_ids: Sequence[str], metric_keys: Sequence[str], title: str = "WER by model (one trace per scenario; use legend to toggle)", ) -> go.Figure: """ Line chart: x = model (sorted by Average WER), y = WER, one trace per scenario. Scaling to more models now just extends the x axis instead of adding more traces, which keeps the legend small (≤ #scenarios) and makes cross-scenario comparisons for the same model trivial (read a vertical slice). `model_ids` filters which models appear on the x axis; empty / None → all rows. """ metric_keys = [k for k in metric_keys if k in df.columns] if df.empty or not metric_keys: return _empty_fig("Not enough data.") selected = [m for m in (model_ids or []) if m] sub = df[df["model_id"].isin(selected)].copy() if selected else df.copy() if sub.empty: return _empty_fig("No matching models.") # Sort models by Average WER (lower is better). 
if "avg_wer" in sub.columns: sub = sub.sort_values("avg_wer", ascending=True, na_position="last") sub = sub.reset_index(drop=True) x_labels = sub["model_id"].str.split("/").str[-1].str[:32].tolist() full_ids = sub["model_id"].tolist() fig = go.Figure() palette = qualitative.Plotly * 3 for i, k in enumerate(metric_keys): m = metric_by_key(k) scen_label = m.short if m else k ys = [ _wer_pct(v) if pd.notna(v) else None for v in sub[k].tolist() ] fig.add_trace( go.Scatter( x=x_labels, y=ys, mode="lines+markers", name=scen_label, customdata=full_ids, line=dict(width=2, color=palette[i % len(palette)]), marker=dict(size=8), hovertemplate=( f"%{{customdata}}
{scen_label}: %{{y:.2f}}%" ), ) ) n_models = len(x_labels) tickangle = -45 if n_models > 8 else -25 fig.update_layout( title=dict(text=title, x=0.5, xanchor="center"), xaxis=dict(title="Model (sorted by Average WER)", tickangle=tickangle, automargin=True), yaxis=dict(title="WER (%) — lower is better", rangemode="tozero"), template=_TEMPLATE, height=_FIG_HEIGHT, margin=dict(l=60, r=120, t=60, b=140), **_plotly_legend_config(), ) return fig def plot_scenario_heatmap( df: pd.DataFrame, metric_keys: Sequence[str], top_n: int = 30, title: str | None = None, ) -> go.Figure: """ Models × scenarios heatmap of WER values (lower = greener). This replaces the per-scenario line chart for the common case of *N≫few* models and a small fixed set of conditions: every row is one model and every column is one condition, so the eye can immediately spot: - which models do badly on which conditions (hot rows / cells), - which conditions are universally hard (whole columns red), - and how the top of the leaderboard compares to the rest. `metric_keys` selects which scenario columns to show (in given order). Models are sorted by Average WER and capped at `top_n` rows so the chart stays legible even with hundreds of leaderboard entries. """ metric_keys = [k for k in metric_keys if k in df.columns] if df.empty or not metric_keys: return _empty_fig("Not enough data for the heatmap.") d = _sort_df_leaderboard_order(df.copy()) d = d.head(max(1, int(top_n))).reset_index(drop=True) # Omit scenarios with no data in the visible rows (avoids sparse heatmaps). 
keep_keys: list[str] = [] for k in metric_keys: col = pd.to_numeric(d[k], errors="coerce") if col.notna().any(): keep_keys.append(k) if not keep_keys: return _empty_fig("No WER values for the selected scenarios.") order_index = {k: i for i, k in enumerate(HEATMAP_SCENARIO_KEYS)} keep_keys = sorted(keep_keys, key=lambda k: order_index.get(k, len(order_index))) x_labels = [heatmap_label_for_key(k) for k in keep_keys] y_labels = d["model_id"].str.split("/").str[-1].str[:42].tolist() z: list[list[float]] = [] cell_text: list[list[str]] = [] for _, row in d.iterrows(): z_row: list[float] = [] t_row: list[str] = [] for k in keep_keys: v = pd.to_numeric(row.get(k), errors="coerce") if pd.isna(v): z_row.append(float("nan")) t_row.append("") else: pct = _wer_pct(float(v)) if pct is None: z_row.append(float("nan")) t_row.append("") else: z_row.append(float(pct)) t_row.append(f"{pct:.1f}") z.append(z_row) cell_text.append(t_row) flat_vals = [v for row in z for v in row if np.isfinite(v)] if flat_vals: raw_max = float(max(flat_vals)) if raw_max <= 25: zmax = 25.0 elif raw_max <= 60: zmax = float(np.ceil(raw_max / 5.0) * 5.0) elif raw_max <= 100: zmax = float(np.ceil(raw_max / 10.0) * 10.0) else: zmax = float(np.ceil(raw_max / 25.0) * 25.0) else: zmax = 100.0 fig = go.Figure( data=go.Heatmap( z=z, x=x_labels, y=y_labels, text=cell_text, texttemplate="%{text}", textfont=dict(size=10, color="#1a1a1a"), hovertemplate="%{y}
%{x}: %{z:.2f}%", colorscale="RdYlGn_r", zmin=0.0, zmax=zmax, xgap=2, ygap=2, colorbar=dict(title="WER (%)", tickformat=".0f"), ) ) n_rows = max(1, len(y_labels)) # Dynamic height so 30 models gets ~460 px and 80 models gets a taller chart, # without becoming uselessly tall. height = int(min(900, max(_FIG_HEIGHT, 80 + 18 * n_rows))) fig.update_layout( title=dict( text=title or f"WER heatmap: top {n_rows} models × {len(keep_keys)} scenarios (by Average WER)", x=0.5, xanchor="center", ), xaxis=dict( title=None, tickangle=-35, side="top", automargin=True, tickfont=dict(size=10), ), yaxis=dict(title="Model (sorted by Average WER)", automargin=True, autorange="reversed"), template=_TEMPLATE, height=height, margin=dict(l=80, r=60, t=110, b=60), ) # Scenario tick labels sit on top (side="top"); placing the axis title there too # collides with the plot title, so render the "Scenario" label at the bottom. fig.add_annotation( text="Scenario", xref="paper", yref="paper", x=0.5, y=-0.02, yanchor="top", showarrow=False, font=dict(size=13), ) return fig def plot_clean_vs_reverb_scatter( df: pd.DataFrame, title: str = "Near Field Speech versus Lab Simulated WER", ) -> go.Figure: """ Scatter: x = near-field speech WER, y = lab-simulated WER. Models often sit above y=x when simulation is harder. 
""" c1, c2 = "wer_anechoic_speech", "wer_lab_simulated" if df.empty or c1 not in df.columns or c2 not in df.columns: return _empty_fig("Need near-field speech and lab-simulated WER columns in the leaderboard.") d = df[["model_id", c1, c2]].copy() d[c1] = pd.to_numeric(d[c1], errors="coerce") d[c2] = pd.to_numeric(d[c2], errors="coerce") d = d.dropna(subset=[c1, c2]) if d.empty: return _empty_fig("No complete near-field speech / lab-simulated WER pairs yet.") fig = go.Figure() x_pct = [_wer_pct(v) for v in d[c1]] y_pct = [_wer_pct(v) for v in d[c2]] mx = float(max((v for v in x_pct + y_pct if v is not None), default=0.0)) fig.add_trace( go.Scatter( x=[0, mx], y=[0, mx], mode="lines", name="y = x (equal WER)", line=dict(dash="dash", color="rgba(0,0,0,0.35)", width=2), hoverinfo="skip", ) ) short = d["model_id"].str.split("/").str[-1].str[:28] fig.add_trace( go.Scatter( x=x_pct, y=y_pct, mode="markers", text=short, marker=dict(size=10, opacity=0.85), hovertemplate=( "%{text}
Near field speech WER: %{x:.2f}%
" "Lab simulated WER: %{y:.2f}%" ), ) ) fig.update_layout( title=dict(text=title, x=0.5, xanchor="center"), xaxis=dict(title="WER (%), Near Field Speech (lower is better)", rangemode="tozero"), yaxis=dict(title="WER (%), Lab Simulated (lower is better)", rangemode="tozero"), template=_TEMPLATE, height=_FIG_HEIGHT, margin=dict(l=60, r=40, t=60, b=60), showlegend=True, ) return fig def plot_pareto_frontier( df: pd.DataFrame, title: str | None = None, ) -> go.Figure: """ Pareto front: X = Average WER (%, live scenarios), Y = RTFx (log scale). Frontier models are highlighted with star markers + name labels and connected by a dashed line; non-frontier models render as light dots. """ if df.empty: return _empty_fig("No leaderboard data yet.") need = {"model_id", "avg_wer", "eval_rtf"} if not need.issubset(df.columns): return _empty_fig("Missing Average WER or RTFx columns for the Pareto chart.") d = df[["model_id", "avg_wer", "eval_rtf"]].copy() d["avg_wer"] = pd.to_numeric(d["avg_wer"], errors="coerce") d["eval_rtf"] = pd.to_numeric(d["eval_rtf"], errors="coerce") d = d.dropna(subset=["avg_wer", "eval_rtf"]) d = d[d["eval_rtf"] > 0] if d.empty: return _empty_fig( "No models with Average WER and RTFx. Re-run evaluations to populate timing." ) d["wer_pct"] = d["avg_wer"].apply(lambda v: _wer_pct(float(v))) d = d.dropna(subset=["wer_pct"]).reset_index(drop=True) # Layer-1 frontier on (avg_wer, eval_rtf): lower WER + higher RTF is better. 
fx_raw, fy = _pareto_efficient_frontier_wer_rtf( d["avg_wer"].tolist(), d["eval_rtf"].tolist(), ) frontier_pairs = set(zip([round(w, 8) for w in fx_raw], [round(r, 8) for r in fy])) def _is_frontier(row) -> bool: return (round(float(row["avg_wer"]), 8), round(float(row["eval_rtf"]), 8)) in frontier_pairs d["_frontier"] = d.apply(_is_frontier, axis=1) frontier_df = d[d["_frontier"]].sort_values("wer_pct").reset_index(drop=True) other_df = d[~d["_frontier"]] accent = "#1f77ff" # cool blue (matches screenshot) accent_light = "rgba(31, 119, 255, 0.30)" fig = go.Figure() if not other_df.empty: fig.add_trace( go.Scatter( x=other_df["wer_pct"], y=other_df["eval_rtf"], mode="markers", name="Other models", marker=dict(size=9, color=accent_light, line=dict(width=0, color="white")), text=other_df["model_id"], hovertemplate=( "%{text}
" "Average WER: %{x:.2f}%
" "RTFx: %{y:.2f}×" ), ) ) if not frontier_df.empty: fig.add_trace( go.Scatter( x=frontier_df["wer_pct"], y=frontier_df["eval_rtf"], mode="lines+markers+text", name="Pareto frontier", line=dict(color=accent, width=2, dash="dash"), marker=dict(size=14, color=accent, symbol="star", line=dict(width=0, color="white")), text=frontier_df["model_id"], textposition="top center", textfont=dict(size=11, color=accent), cliponaxis=False, hovertemplate=( "%{text}
" "Average WER: %{x:.2f}%
" "RTFx: %{y:.2f}×" ), ) ) ttl = title or "Pareto Front: Average WER vs RTFx" rtf_min = float(d["eval_rtf"].min()) rtf_max = float(d["eval_rtf"].max()) log_y_lo = np.log10(max(rtf_min * 0.85, 1e-4)) log_y_hi = np.log10(rtf_max * 1.2) fig.update_layout( title=dict(text=ttl, x=0.0, xanchor="left", font=dict(size=18)), xaxis=dict( title="Average WER (lower is better)", ticksuffix="", zeroline=False, ), template=_TEMPLATE, height=_FIG_HEIGHT, autosize=True, margin=dict(l=70, r=40, t=70, b=60), showlegend=True, legend=dict( yanchor="top", y=0.98, xanchor="right", x=0.99, bgcolor="rgba(255,255,255,0.7)", borderwidth=0, ), ) fig.update_yaxes( title="RTFx (higher is better)", type="log", zeroline=False, range=[log_y_lo, log_y_hi], tickformat=".2g", ) fig.update_xaxes( title="Average WER (lower is better)", zeroline=False, range=[0, 100], ) return fig def plot_scenario_bar_summary(df: pd.DataFrame, top_n: int = 8) -> go.Figure: """Grouped bars: one trace per model; legend toggles models.""" if df.empty: return _empty_fig("No leaderboard data yet.") cols = [c for c in HEATMAP_SCENARIO_KEYS if c in df.columns] if not cols: return _empty_fig("No core WER columns.") d = df.copy() if "avg_wer" in d.columns: d = d.sort_values("avg_wer", ascending=True, na_position="last").head(int(top_n)) else: d["_avg"] = d[cols].mean(axis=1, skipna=True) d = d.sort_values("_avg", ascending=True).head(int(top_n)) scenarios = cols x_labels = [metric_by_key(c).short if metric_by_key(c) else c for c in scenarios] fig = go.Figure() palette = qualitative.Plotly * 3 for i, (_, row) in enumerate(d.iterrows()): heights = [ _wer_pct(row[c]) if pd.notna(row[c]) else None for c in scenarios ] short = row["model_id"].split("/")[-1][:22] fig.add_trace( go.Bar( name=short, x=x_labels, y=heights, marker_color=palette[i % len(palette)], hovertemplate=f"{short}
%{{x}}: %{{y:.2f}}%", ) ) fig.update_layout( barmode="group", title=dict( text=f"WER by scenario: top {len(d)} models by Average WER (legend toggles models)", x=0.5, xanchor="center", ), xaxis=dict(title="Scenario", tickangle=-20), yaxis=dict(title="WER (%)", rangemode="tozero"), template=_TEMPLATE, height=_FIG_HEIGHT, margin=dict(l=60, r=120, t=60, b=100), **_plotly_legend_config(), ) return fig def plot_lines_by_rank(df: pd.DataFrame, metric_keys: Sequence[str], title: str | None = None) -> go.Figure: """Deprecated: use plot_clean_vs_reverb_scatter or plot_latency_vs_wer.""" return plot_clean_vs_reverb_scatter(df, title=title or "Clean vs reverberant WER")