Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """ | |
| Interactive Plotly charts for the Analysis tab (legend toggles traces, zoom, pan, hover). | |
| Radar = compare models on normalized axes. Middle chart = WER across models (one trace per scenario). | |
| Bar chart = grouped WER by scenario for top models. | |
| """ | |
| from __future__ import annotations | |
| import threading | |
| from typing import Sequence | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from plotly.colors import qualitative | |
| from metrics_config import ( | |
| HEATMAP_SCENARIO_KEYS, | |
| LIVE_SCENARIO_KEYS, | |
| SCENARIO_METRICS, | |
| heatmap_label_for_key, | |
| metric_by_key, | |
| ) | |
# Scenario columns for the realistic high / mid / low SNR splits. Their mean is
# the WER coordinate used for the Pareto X-axis and for ranking.
SNR_WER_KEYS: tuple[str, ...] = (
    "wer_realistic_high_snr", "wer_realistic_mid_snr", "wer_realistic_low_snr",
)
| def _wer_pct(v) -> float | None: | |
| """Fractional WER (0–1) → percent for chart display.""" | |
| if v is None or (isinstance(v, float) and not np.isfinite(v)): | |
| return None | |
| try: | |
| f = float(v) | |
| except (TypeError, ValueError): | |
| return None | |
| if not np.isfinite(f): | |
| return None | |
| return f * 100.0 | |
def _pareto_efficient_frontier_wer_rtf(
    wers: Sequence[float], rtfs: Sequence[float]
) -> tuple[list[float], list[float]]:
    """Return the layer-1 Pareto envelope as parallel ``(wer, rtf)`` lists.

    X = WER (lower is better), Y = RTF (higher is better). Pairs with a
    non-finite coordinate are ignored. A point is kept unless some other point
    is at least as good on both axes and strictly better on one. The result is
    ordered by ascending WER; two empty lists are returned when no finite pair
    exists.
    """
    candidates = [
        (float(w), float(r))
        for w, r in zip(wers, rtfs)
        if np.isfinite(w) and np.isfinite(r)
    ]
    if not candidates:
        return [], []

    def _is_dominated(idx: int) -> bool:
        """True if any *other* candidate weakly beats ``idx`` and strictly beats it on one axis."""
        w0, r0 = candidates[idx]
        return any(
            other != idx and w1 <= w0 and r1 >= r0 and (w1 < w0 or r1 > r0)
            for other, (w1, r1) in enumerate(candidates)
        )

    envelope = sorted(
        (pt for idx, pt in enumerate(candidates) if not _is_dominated(idx)),
        key=lambda pt: pt[0],
    )
    return [w for w, _ in envelope], [r for _, r in envelope]
# One shared height so every chart lines up inside the Gradio Plot component.
_FIG_HEIGHT: int = 460
# Plotly layout template used by every figure built in this module.
_TEMPLATE: str = "plotly_white"
# State for the weight-file size lookup (`_cached_model_size_mb`).
# File suffixes that count as model weights when summing sizes.
_WEIGHT_EXTS = (".safetensors", ".bin", ".pt", ".pth", ".gguf", ".ckpt", ".msgpack", ".h5")
# Guards every read/write of the cache below.
_model_size_lock = threading.Lock()
# model_id -> total weight size in MB. Filled lazily on first render; holds NaN
# when the hub call failed (e.g. gated repo without token).
_model_size_cache: dict[str, float] = {}
def _cached_model_size_mb(model_id: str) -> float:
    """Return the summed weight-file size of ``model_id`` in MB, memoised per process.

    The first call for a model queries the Hugging Face Hub for file metadata
    and adds up the sizes of files whose name ends in one of ``_WEIGHT_EXTS``.
    Any failure (or a zero total) yields NaN. The result, NaN included, is
    cached, so each model is looked up at most once per successful cache write.
    """
    with _model_size_lock:
        try:
            return _model_size_cache[model_id]
        except KeyError:
            pass

    # The network call runs outside the lock so slow lookups do not block
    # readers of already-cached models.
    try:
        from huggingface_hub import HfApi

        siblings = HfApi().model_info(model_id, files_metadata=True).siblings or []
        weight_bytes = sum(
            int(getattr(entry, "size", None) or 0)
            for entry in siblings
            if (getattr(entry, "rfilename", "") or "").lower().endswith(_WEIGHT_EXTS)
        )
    except Exception:
        # Deliberate best effort: a failed lookup degrades to "unknown size".
        weight_bytes = 0

    if weight_bytes > 0:
        size_mb = float(weight_bytes) / (1024.0 * 1024.0)
    else:
        size_mb = float("nan")

    with _model_size_lock:
        _model_size_cache[model_id] = size_mb
    return size_mb
| def _coerce_wer(v) -> float | None: | |
| """Non-negative finite WER, or None if missing / invalid.""" | |
| if v is None or v == "": | |
| return None | |
| try: | |
| x = float(v) | |
| return x if np.isfinite(x) and x >= 0 else None | |
| except (TypeError, ValueError): | |
| return None | |
| def _coerce_rtf(v) -> float | None: | |
| """Non-negative finite RTF (audio sec / inference sec), or None if missing.""" | |
| if v is None or v == "": | |
| return None | |
| try: | |
| x = float(v) | |
| return x if np.isfinite(x) and x >= 0 else None | |
| except (TypeError, ValueError): | |
| return None | |
def snr_avg_wer(row: dict) -> float | None:
    """Average the row's WER over the ``SNR_WER_KEYS`` scenarios.

    Entries that are missing or invalid (per ``_coerce_wer``) are skipped;
    ``None`` is returned when none of the scenarios has a usable value.
    """
    present = [
        wer
        for wer in (_coerce_wer(row.get(key)) for key in SNR_WER_KEYS)
        if wer is not None
    ]
    return float(sum(present) / len(present)) if present else None
def compute_pareto_layers(rows: list[dict]) -> dict[str, int | None]:
    """Assign each model a Pareto layer over (snr_avg_wer ↓, eval_rtf ↑).

    Layer 1 is the efficient frontier; it is removed and the frontier of the
    remainder becomes layer 2, and so on. Rows without a ``model_id`` are
    skipped. Models lacking a usable WER or RTF map to ``None``.

    Returns:
        Mapping of stripped ``model_id`` to its layer number (or ``None``).
    """
    layers: dict[str, int | None] = {}
    scored: list[tuple[str, float, float]] = []
    for row in rows:
        model_id = (row.get("model_id") or "").strip()
        if not model_id:
            continue
        wer = snr_avg_wer(row)
        rtf = _coerce_rtf(row.get("eval_rtf"))
        if wer is None or rtf is None:
            layers[model_id] = None
            continue
        scored.append((model_id, wer, rtf))

    def _is_dominated(idx: int, pool: list[tuple[str, float, float]]) -> bool:
        """True if another pool entry is no worse on both axes and strictly better on one."""
        _, w0, r0 = pool[idx]
        return any(
            other != idx and w1 <= w0 and r1 >= r0 and (w1 < w0 or r1 > r0)
            for other, (_, w1, r1) in enumerate(pool)
        )

    remaining = list(scored)
    depth = 1
    while remaining:
        front_ids = [
            entry[0]
            for idx, entry in enumerate(remaining)
            if not _is_dominated(idx, remaining)
        ]
        for model_id in front_ids:
            layers[model_id] = depth
        # Peel by id: every remaining entry sharing a frontier id leaves the pool.
        peeled = set(front_ids)
        remaining = [entry for entry in remaining if entry[0] not in peeled]
        depth += 1
    return layers
def _avg_wer_for_row(row: dict) -> float:
    """Average the row's WER over ``LIVE_SCENARIO_KEYS``.

    Missing / invalid entries are skipped. A row with no usable scenario
    returns ``inf`` so it sorts after every scored row.
    """
    present = [
        wer
        for wer in (_coerce_wer(row.get(key)) for key in LIVE_SCENARIO_KEYS)
        if wer is not None
    ]
    return sum(present) / len(present) if present else float("inf")
def sort_leaderboard_rows_inplace(rows: list[dict]) -> None:
    """Sort ``rows`` in place by average WER ascending, breaking ties by ``model_id``."""

    def _rank(row: dict) -> tuple[float, str]:
        return _avg_wer_for_row(row), row.get("model_id") or ""

    rows.sort(key=_rank)
| def frontier_label(layer: int | None) -> str: | |
| return str(layer) if layer is not None else "NA" | |
def _log_normalize(values: np.ndarray, *, inverse: bool, strength_floor: float = 0.1) -> np.ndarray:
    """Map raw metric values to a log10-scaled "strength" score.

    Valid entries (finite and > 0) are min-max normalized in log10 space across
    the whole array, then clamped to ``[strength_floor, 1.0]``; the weakest
    entry therefore lands on the floor rather than on 0.

    * ``inverse=True``  -> lower raw value is better (WER, memory): strength = 1 - norm.
    * ``inverse=False`` -> higher raw value is better (RTF): strength = norm.

    When all valid entries are (numerically) equal they all receive 0.5.
    Missing or non-positive entries stay NaN.
    """
    raw = np.asarray(values, dtype=float)
    strength = np.full_like(raw, np.nan, dtype=float)
    usable = np.isfinite(raw) & (raw > 0)
    if not usable.any():
        return strength

    logs = np.log10(np.maximum(raw[usable], 1e-12))
    lo = float(np.min(logs))
    hi = float(np.max(logs))
    if hi - lo < 1e-9:
        # Degenerate spread: no ordering information, place everyone mid-scale.
        norm = np.full(logs.shape, 0.5)
    else:
        norm = (logs - lo) / (hi - lo)
    if inverse:
        norm = 1.0 - norm

    strength[usable] = np.clip(norm, strength_floor, 1.0)
    return strength
| # Axis spec: (strength_label, primary col, fallback col, inverse, raw_display_name, raw_unit) | |
| # strength_label: chart axis label; always phrased so "farther = better". | |
| # inverse : True if *lower* raw value is better (e.g. WER, memory); False if higher is better (RTF). | |
| # raw_display_name: human-readable name of the underlying metric, shown in the hover tooltip. | |
| # raw_unit : optional unit suffix for the raw value in the tooltip. | |
| # | |
| # Live far-field scenario axes + Speed (RTF) + Compactness (parameters / Hub size). | |
| _RADAR_AXES: tuple[tuple[str, str, str | None, bool, str, str], ...] = ( | |
| ("Near Field Speech", "wer_anechoic_speech", None, True, "Near field speech WER", ""), | |
| ("Lab Measured", "wer_lab_measured", None, True, "Lab measured WER", ""), | |
| ("Lab Simulated", "wer_lab_simulated", None, True, "Lab simulated WER", ""), | |
| ("High SNR", "wer_realistic_high_snr", None, True, "High SNR WER", ""), | |
| ("Mid SNR", "wer_realistic_mid_snr", None, True, "Mid SNR WER", ""), | |
| ("Low SNR", "wer_realistic_low_snr", None, True, "Low SNR WER", ""), | |
| ("Moving Sources", "wer_moving_sources", None, True, "Moving Sources WER", ""), | |
| ("Speed", "eval_rtf", None, False, "RTFx", "× realtime"), | |
| # Compactness: smaller model -> more compact -> farther from centre. | |
| # `inverse=True` makes `_log_normalize` emit `1 - norm`, so the smallest | |
| # observed model maps to strength = 1.0 (radar edge) and the largest | |
| # maps to 0.0 (centre). | |
| # | |
| # Primary signal: parameter count (millions), which is recorded for every | |
| # backend by `attach_params` and is much more reliable than the Hub | |
| # weight-file size lookup (which can return NaN for gated/private repos | |
| # or repos that ship weights via LFS pointers without size metadata). | |
| # `model_size_mb` is the fallback for legacy rows that pre-date param | |
| # tracking. | |
| ("Compactness", "num_params_m", "model_size_mb", True, "Parameters", " M"), | |
| ) | |
# Tooltip override used when a radar axis takes its value from the fallback
# column instead of the primary one. Keyed by the *primary* column; the value
# is (display_name, unit) for the fallback metric. Without it the tooltip would
# read e.g. "Parameters: 1500.0000 M" although the plotted number is really
# "Model size: 1500.0000 MB".
_RADAR_FALLBACK_DISPLAY: dict[str, tuple[str, str]] = {"num_params_m": ("Model size", " MB")}
def _empty_fig(message: str) -> go.Figure:
    """Build a blank placeholder figure showing ``message`` centred, with both axes hidden."""
    placeholder = go.Figure()
    annotation = {
        "text": message,
        "xref": "paper",
        "yref": "paper",
        "x": 0.5,
        "y": 0.5,
        "showarrow": False,
        "font": {"size": 14},
    }
    placeholder.add_annotation(**annotation)
    placeholder.update_layout(
        template=_TEMPLATE,
        height=400,
        margin={"l": 40, "r": 40, "t": 40, "b": 40},
        xaxis={"visible": False},
        yaxis={"visible": False},
    )
    return placeholder
def _plotly_legend_config() -> dict:
    """Return layout kwargs shared by the multi-trace charts.

    Places a vertical legend just right of the plot area (clicking an entry
    toggles its trace; double-click isolates it) and sets hover to the closest
    point. A fresh dict is built on every call.
    """
    legend = {
        "orientation": "v",
        "yanchor": "top",
        "y": 1,
        "xanchor": "left",
        "x": 1.02,
        "font": {"size": 10},
    }
    return {"legend": legend, "hovermode": "closest"}
def _raw_to_analytics_df(raw: list[dict]) -> pd.DataFrame:
    """Turn raw leaderboard rows into the DataFrame consumed by the charts.

    Each row is passed through ``normalize_legacy_csv_row``; scenario WER and
    timing / size columns are coerced to float (NaN if missing or unparseable).
    An empty frame is returned for empty input or when no ``model_id`` column
    exists.

    Derived columns:
      * ``num_params_m``: ``num_params`` in millions.
      * ``avg_wer``: mean of the live scenario WERs present for that row.
      * ``pareto_layer``: layer from ``compute_pareto_layers`` (``None`` if unranked).
      * ``snr_avg_wer``: mean WER over the SNR scenarios.
      * ``model_size_mb``: total weight-file size fetched from the Hub (cached).
    """
    from init import normalize_legacy_csv_row

    if not raw:
        return pd.DataFrame()

    rows = [normalize_legacy_csv_row(dict(entry)) for entry in raw]
    df = pd.DataFrame(rows)
    if "model_id" not in df.columns:
        return pd.DataFrame()

    numeric_cols = [metric.key for metric in SCENARIO_METRICS]
    numeric_cols += ["eval_wall_time_s", "eval_rtf", "eval_audio_seconds", "num_params"]
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
        else:
            df[col] = np.nan

    # Millions of parameters; compact display.
    df["num_params_m"] = df["num_params"] / 1e6

    live_cols = [key for key in LIVE_SCENARIO_KEYS if key in df.columns]
    if live_cols:
        df["avg_wer"] = df[live_cols].mean(axis=1, skipna=True)
    else:
        df["avg_wer"] = np.nan

    layer_by_id = compute_pareto_layers(rows)
    df["pareto_layer"] = [
        layer_by_id.get((row.get("model_id") or "").strip()) for row in rows
    ]
    df["snr_avg_wer"] = [snr_avg_wer(row) for row in rows]
    df["model_size_mb"] = df["model_id"].apply(_cached_model_size_mb)
    return df
def _sort_df_leaderboard_order(df: pd.DataFrame) -> pd.DataFrame:
    """Order rows by ``avg_wer`` ascending (lower is better), NaN rows last.

    An empty frame is returned as-is. Otherwise the input is left untouched and
    a copy is returned, unsorted when there is no ``avg_wer`` column.
    """
    if df.empty:
        return df
    ordered = df.copy()
    if "avg_wer" not in ordered.columns:
        return ordered
    return ordered.sort_values("avg_wer", ascending=True, na_position="last")
def available_metric_keys(df: pd.DataFrame) -> list[str]:
    """List the scenario keys (in ``SCENARIO_METRICS`` order) whose column has at least one non-NaN value."""
    return [
        metric.key
        for metric in SCENARIO_METRICS
        if metric.key in df.columns and df[metric.key].notna().any()
    ]
# Brand-aligned bar colors, keyed by lower-cased HF org prefix (see
# `_company_key`). Several orgs of the same company share one hex value.
# Orgs not listed here get a color from `_FALLBACK_PALETTE` instead.
_COMPANY_COLORS: dict[str, str] = {
    # NVIDIA green
    "nvidia": "#76B900",
    # OpenAI black
    "openai": "#111111",
    "openai-community": "#111111",
    # Qwen purple
    "qwen": "#615CED",
    # IBM blue
    "ibm-granite": "#0F62FE",
    "ibm": "#0F62FE",
    # Cohere muted green
    "coherelabs": "#39594D",
    "cohere": "#39594D",
    "efficient-speech": "#F2994A",
    "usefulsensors": "#22A7F0",
    # Google blue
    "google": "#4285F4",
    "google-t5": "#4285F4",
    # Meta blue
    "meta-llama": "#1877F2",
    "facebook": "#1877F2",
    "microsoft": "#00A4EF",
    "deepseek-ai": "#1D6FE0",
    "deepseek": "#1D6FE0",
    "moonshotai": "#111111",
    "mistralai": "#FF6F00",
    "anthropic": "#CC9B73",
    "x-ai": "#000000",
    "xai": "#000000",
    "zhipuai": "#1F8AFF",
    "zai-org": "#1F8AFF",
}
# Colors handed out (cyclically) to orgs that have no entry in `_COMPANY_COLORS`.
_FALLBACK_PALETTE: tuple[str, ...] = (
    "#1F77B4",
    "#FF7F0E",
    "#2CA02C",
    "#D62728",
    "#9467BD",
    "#8C564B",
    "#E377C2",
    "#7F7F7F",
    "#BCBD22",
    "#17BECF",
)
def _company_key(model_id: str) -> str:
    """Return the lower-cased org part of a HF model id (``org/name`` -> ``org``).

    Ids without a slash are returned whole (lower-cased); ``None`` / empty
    input yields ``""``.
    """
    # `partition` leaves the whole string in the first slot when there is no
    # "/", which is exactly the no-org fallback.
    org, _, _ = (model_id or "").strip().partition("/")
    return org.lower()
def _company_color_map(model_ids: Sequence[str]) -> dict[str, str]:
    """Map every model id to a hex color so that models of one org share a color.

    Orgs listed in ``_COMPANY_COLORS`` use their brand color. The remaining
    orgs are sorted alphabetically and take ``_FALLBACK_PALETTE`` entries in
    that order (wrapping around), so the assignment is deterministic for a
    given set of model ids.
    """
    org_by_model = {mid: _company_key(mid) for mid in model_ids}
    unbranded = sorted(
        {org for org in org_by_model.values() if org not in _COMPANY_COLORS}
    )
    palette_size = len(_FALLBACK_PALETTE)
    fallback = {
        org: _FALLBACK_PALETTE[slot % palette_size]
        for slot, org in enumerate(unbranded)
    }
    return {
        mid: _COMPANY_COLORS[org] if org in _COMPANY_COLORS else fallback[org]
        for mid, org in org_by_model.items()
    }
def _short_model_label(model_id: str, max_len: int = 26) -> str:
    """Return the id without its org prefix, cut to ``max_len`` characters.

    Everything after the first ``/`` is kept (the whole id when there is no
    slash). Over-long names are cut to ``max_len - 1`` characters plus ``…``.
    """
    head, slash, tail = (model_id or "").partition("/")
    name = tail if slash else head
    if len(name) <= max_len:
        return name
    return name[: max_len - 1] + "…"
def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Parse a 6-digit hex color (leading ``#`` optional) into ``(r, g, b)`` ints."""
    digits = hex_color.lstrip("#")
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return red, green, blue
def _mix(c: int, target: int, t: float) -> int:
    """Blend channel ``c`` toward ``target`` by fraction ``t`` (clamped to [0, 1]) and round to int."""
    fraction = max(0.0, min(1.0, t))
    blended = c + (target - c) * fraction
    return int(round(blended))
def _lighten(hex_color: str, t: float) -> str:
    """Blend ``hex_color`` toward white by ``t`` in [0, 1]; ``t=0`` keeps the color.

    Returns an upper-case ``#RRGGBB`` string.
    """
    channels = (_mix(channel, 255, t) for channel in _hex_to_rgb(hex_color))
    return "#" + "".join(f"{channel:02X}" for channel in channels)
def _darken(hex_color: str, t: float) -> str:
    """Blend ``hex_color`` toward black by ``t`` in [0, 1]; ``t=0`` keeps the color.

    Returns an upper-case ``#RRGGBB`` string.
    """
    red, green, blue = (_mix(channel, 0, t) for channel in _hex_to_rgb(hex_color))
    return f"#{red:02X}{green:02X}{blue:02X}"
def _to_rgba(hex_color: str, alpha: float) -> str:
    """Format ``hex_color`` as a CSS ``rgba(r,g,b,a)`` string with alpha to three decimals."""
    return "rgba({},{},{},{:.3f})".format(*_hex_to_rgb(hex_color), alpha)
def _company_shade_assignments(
    model_ids: Sequence[str],
    color_map: dict[str, str],
) -> tuple[list[str], list[str], list[float]]:
    """Derive per-bar fill, border and label opacity from the org colors.

    When an org appears more than once in the chart, its bars are made
    progressively lighter and slightly more transparent so neighbouring bars
    of the same brand color can still be told apart.

    Returns three parallel lists (one entry per ``model_ids`` item):
      * fill color as ``rgba(...)`` with the per-bar alpha baked in,
      * border color as hex (the base color darkened by 35%),
      * text opacity (never below 0.9, so in-bar labels stay legible).
    """
    orgs = [_company_key(mid) for mid in model_ids]
    totals: dict[str, int] = {}
    for org in orgs:
        totals[org] = totals.get(org, 0) + 1

    used: dict[str, int] = {}
    fills: list[str] = []
    borders: list[str] = []
    text_alphas: list[float] = []
    for mid, org in zip(model_ids, orgs):
        base = color_map[mid]
        rank = used.get(org, 0)  # 0 for the org's first bar, 1 for the second, ...
        used[org] = rank + 1

        if totals[org] > 1:
            # The first bar keeps the base color. Each later one is lightened by
            # a further 12% (at most 45%) and loses 8% alpha (never below 0.65).
            tint = _lighten(base, min(0.45, 0.12 * rank))
            fill = _to_rgba(tint, max(0.65, 1.0 - 0.08 * rank))
            text_alpha = max(0.9, 1.0 - 0.05 * rank)
        else:
            fill = _to_rgba(base, 1.0)
            text_alpha = 1.0

        fills.append(fill)
        borders.append(_darken(base, 0.35))
        text_alphas.append(text_alpha)
    return fills, borders, text_alphas
def _plot_company_bars(
    labels: Sequence[str],
    values: Sequence[float],
    model_ids: Sequence[str],
    *,
    title: str,
    subtitle: str,
    value_fmt: str,
    hover_value_fmt: str,
    hover_metric_label: str,
) -> go.Figure:
    """Render the shared vertical bar chart used by the WER and Speed charts.

    Args:
        labels: Tick label per bar (may contain duplicates).
        values: Bar heights, parallel to ``labels``.
        model_ids: Full model ids, parallel to ``labels``; drive colors and hover.
        title: Bold chart title.
        subtitle: Smaller grey line rendered under the title.
        value_fmt: ``str.format`` pattern for the in-bar value label.
        hover_value_fmt: Plotly hover template fragment for the value.
        hover_metric_label: Metric name shown before the value on hover.

    Styling:
      - Bars are colored per HF org via ``_company_color_map``; repeated orgs
        get lighter, slightly transparent fills (``_company_shade_assignments``).
      - Every bar has a thin darker outline and rounded corners.
      - The value is drawn inside the bar in white.
      - Hover shows the full model id and the metric value.
    """
    ids = list(model_ids)
    fills, borders, text_alphas = _company_shade_assignments(
        ids, _company_color_map(ids)
    )
    bar_text = [value_fmt.format(v) for v in values]
    label_colors = [f"rgba(255,255,255,{alpha:.3f})" for alpha in text_alphas]

    # Bars sit on integer slots rather than on their labels: two models whose
    # short names coincide would otherwise be stacked into one bar. The labels
    # are attached to the slots as tick text in the layout below.
    slots = list(range(len(labels)))

    hover = (
        "<b>%{customdata}</b><br>"
        f"{hover_metric_label}: {hover_value_fmt}<extra></extra>"
    )
    bars = go.Bar(
        x=slots,
        y=values,
        customdata=ids,
        marker=dict(
            color=fills,
            line=dict(color=borders, width=1.6),
            cornerradius=8,
        ),
        text=bar_text,
        textposition="inside",
        insidetextanchor="middle",
        textfont=dict(
            color=label_colors, size=13, family="Inter, Arial, sans-serif"
        ),
        hovertemplate=hover,
    )
    fig = go.Figure(bars)

    numeric = [float(v) for v in values if v is not None]
    tallest = max(numeric) if numeric else 1.0
    heading = (
        f"<b>{title}</b>"
        f"<br><span style='font-size:12px;color:#6b7280'>{subtitle}</span>"
    )
    fig.update_layout(
        title=dict(text=heading, x=0.0, xanchor="left", font=dict(size=18)),
        xaxis=dict(
            tickmode="array",
            tickvals=slots,
            ticktext=list(labels),
            tickangle=-40,
            automargin=True,
            showgrid=False,
            showline=False,
        ),
        # Hidden value axis with 18% headroom above the tallest bar.
        yaxis=dict(visible=False, range=[0, tallest * 1.18]),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=20, r=20, t=80, b=120),
        showlegend=False,
        bargap=0.25,
        plot_bgcolor="white",
    )
    return fig
def plot_avg_wer_bars(
    df: pd.DataFrame,
    top_n: int = 10,
) -> go.Figure:
    """Vertical bars of average WER (``avg_wer``) in percent for the best models.

    The ``top_n`` models with the lowest WER are shown, best first (at least
    one bar even if ``top_n`` < 1). Rendering (in-bar values, rounded corners,
    one color per HF org) is delegated to ``_plot_company_bars``. A placeholder
    figure is returned when there is no WER data.
    """
    if df.empty or "avg_wer" not in df.columns:
        return _empty_fig("No leaderboard WER data yet.")

    wer = df[["model_id", "avg_wer"]].dropna(subset=["avg_wer"]).copy()
    if wer.empty:
        return _empty_fig("No WER values computed.")

    wer["avg_wer"] = pd.to_numeric(wer["avg_wer"], errors="coerce")
    wer = wer.dropna(subset=["avg_wer"])
    wer = wer.assign(wer_pct=wer["avg_wer"].apply(lambda v: _wer_pct(float(v))))
    limit = max(1, int(top_n))
    ranked = (
        wer.dropna(subset=["wer_pct"])
        .sort_values("wer_pct", ascending=True)
        .head(limit)
    )

    ids = ranked["model_id"].tolist()
    return _plot_company_bars(
        labels=[_short_model_label(mid) for mid in ids],
        values=[float(pct) for pct in ranked["wer_pct"].tolist()],
        model_ids=ids,
        title="WER",
        subtitle="Average WER (%) · Lower is better",
        value_fmt="{:.1f}",
        hover_value_fmt="%{y:.2f}%",
        hover_metric_label="Average WER",
    )
def plot_speed_bars(
    df: pd.DataFrame,
    top_n: int = 10,
) -> go.Figure:
    """Vertical bars of throughput (``eval_rtf``, × realtime) for the fastest models.

    Companion to ``plot_avg_wer_bars``: the ``top_n`` models with the highest
    positive RTF are shown, fastest first, using the same per-org colors and
    in-bar labels. A placeholder figure is returned when no positive RTF exists.
    """
    no_data = "No RTF measurements yet."
    if df.empty or "eval_rtf" not in df.columns:
        return _empty_fig(no_data)

    speeds = df[["model_id", "eval_rtf"]].copy()
    speeds["eval_rtf"] = pd.to_numeric(speeds["eval_rtf"], errors="coerce")
    speeds = speeds[speeds["eval_rtf"].notna() & (speeds["eval_rtf"] > 0)]
    if speeds.empty:
        return _empty_fig(no_data)

    ranked = speeds.sort_values("eval_rtf", ascending=False).head(max(1, int(top_n)))
    ids = ranked["model_id"].tolist()
    rtfs = [float(v) for v in ranked["eval_rtf"].tolist()]

    # One decimal while every bar is below 10×; whole numbers once any bar
    # reaches 10× (e.g. "0.7" vs "314").
    label_fmt = "{:.1f}" if max(rtfs) < 10 else "{:.0f}"
    return _plot_company_bars(
        labels=[_short_model_label(mid) for mid in ids],
        values=rtfs,
        model_ids=ids,
        title="Speed",
        subtitle="RTFx (audio sec / inference sec) · Higher is better",
        value_fmt=label_fmt,
        hover_value_fmt="%{y:.2f}×",
        hover_metric_label="RTFx",
    )
def plot_leaderboard_score_bars(
    df: pd.DataFrame,
    top_n: int = 40,
    title: str | None = None,
) -> go.Figure:
    """Ranked horizontal bar chart of average WER (percent), best model on top.

    Args:
        df: Analytics frame with ``model_id`` and ``avg_wer`` columns.
        top_n: Maximum number of models to show (at least one is shown).
        title: Chart title; defaults to ``"Average WER (top N models)"``.

    Returns:
        The bar figure, or a placeholder figure when no WER data is available.
    """
    if df.empty or "avg_wer" not in df.columns:
        return _empty_fig("No leaderboard WER data yet.")
    d = df[["model_id", "avg_wer"]].dropna(subset=["avg_wer"]).copy()
    if d.empty:
        return _empty_fig("No WER values computed.")
    d = _sort_df_leaderboard_order(d).head(max(1, int(top_n)))

    labels = d["model_id"].str.split("/").str[-1].str[:42].tolist()
    scores = [_wer_pct(v) or 0.0 for v in d["avg_wer"].astype(float)]

    # Place bars on integer slots instead of on their labels. The labels are
    # org-stripped and truncated, so two models can share one; as category
    # values they would be merged into a single stacked bar. The labels are
    # mapped back onto the slots as tick text.
    positions = list(range(len(labels)))

    fig = go.Figure(
        go.Bar(
            x=scores,
            y=positions,
            orientation="h",
            customdata=labels,
            marker=dict(
                color=scores,
                colorscale="RdYlGn_r",
                showscale=True,
                colorbar=dict(title="WER (%)"),
            ),
            text=[f"{s:.1f}" for s in scores],
            textposition="outside",
            hovertemplate="<b>%{customdata}</b><br>Average WER: %{x:.2f}%<extra></extra>",
        )
    )
    fig.update_layout(
        title=dict(
            text=title or f"Average WER (top {len(d)} models)",
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(title="WER (%)", rangemode="tozero"),
        yaxis=dict(
            autorange="reversed",  # slot 0 (lowest WER) at the top
            tickmode="array",
            tickvals=positions,
            ticktext=labels,
        ),
        template=_TEMPLATE,
        # Grow with the number of bars, but stay within [_FIG_HEIGHT, 920] px.
        height=int(min(920, max(_FIG_HEIGHT, 100 + 22 * len(d)))),
        margin=dict(l=200, r=100, t=60, b=50),
    )
    return fig
| def _radar_absolute_value(row: pd.Series, axis: str) -> tuple[float, str]: | |
| """Map a leaderboard row to an absolute 0–1 radar coordinate (higher = better).""" | |
| if axis == "WER": | |
| wer = pd.to_numeric(row.get("avg_wer"), errors="coerce") | |
| if not np.isfinite(wer) or float(wer) < 0: | |
| return 0.0, "Average WER: N/A" | |
| w = float(wer) | |
| pct = w * 100.0 | |
| v = float(max(0.0, min(1.0, 1.0 / (1.0 + w)))) | |
| return v, f"Average WER: {pct:.2f}% → strength {v:.3f}" | |
| if axis == "Speed": | |
| rtf = pd.to_numeric(row.get("eval_rtf"), errors="coerce") | |
| if not np.isfinite(rtf) or float(rtf) < 0: | |
| return 0.0, "RTFx: N/A" | |
| r = float(rtf) | |
| v = float(max(0.0, min(1.0, r / (1.0 + r)))) | |
| return v, f"RTFx: {r:.4f}× → speed {v:.3f} (RTFx/(1+RTFx))" | |
| if axis == "Compactness": | |
| pm = pd.to_numeric(row.get("num_params_m"), errors="coerce") | |
| if not np.isfinite(pm) or float(pm) < 0: | |
| return 0.0, "Parameters: N/A" | |
| p = float(pm) | |
| v = float(max(0.0, min(1.0, 1.0 / (1.0 + p / 1000.0)))) | |
| return v, f"Parameters: {p:.2f} M → compactness {v:.3f}" | |
| return 0.0, f"{axis}: N/A" | |
# Axis names understood by `_radar_absolute_value`, in the order they are drawn.
_RADAR_ABSOLUTE_AXES: tuple[str, ...] = (
    "WER",
    "Speed",
    "Compactness",
)
def plot_robustness_radar(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    title: str = "Robustness radar (absolute 0–1; outward is better)",
) -> go.Figure:
    """Three-axis radar with **absolute** coordinates in [0, 1] (not relative to other models).

    WER strength = 1 / (1 + avg_wer); Speed = RTFx / (1 + RTFx);
    Compactness = 1 / (1 + num_params_m / 1000). Missing values map to 0.

    Args:
        df: Analytics frame; must contain ``model_id`` plus the metric columns.
        model_ids: Models to draw, one trace each. Ids not present in ``df``
            are ignored; if none remain, the first five rows of ``df`` are used.
        title: Chart title.

    Returns:
        The radar figure, or a placeholder figure when ``df`` is empty.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")
    d = df.copy().reset_index(drop=True)

    axes = list(_RADAR_ABSOLUTE_AXES)
    # Repeat the first axis at the end so each polygon is drawn closed.
    axes_closed = axes + [axes[0]]

    # Build the membership set once, not once per requested id.
    known_ids = set(d["model_id"])
    selected = [m for m in (model_ids or []) if m in known_ids]
    if not selected:
        selected = d["model_id"].tolist()[:5]

    fig = go.Figure()
    palette = qualitative.Plotly
    for i, mid in enumerate(selected):
        matches = d[d["model_id"] == mid]
        if matches.empty:
            continue
        row = matches.iloc[0]  # first row wins when an id occurs more than once

        r_vals: list[float] = []
        tips: list[str] = []
        for ax in axes:
            value, tip = _radar_absolute_value(row, ax)
            r_vals.append(value)
            tips.append(tip)

        short = mid.split("/")[-1][:28]
        color = palette[i % len(palette)]
        fig.add_trace(
            go.Scatterpolar(
                r=r_vals + [r_vals[0]],
                theta=axes_closed,
                customdata=tips + [tips[0]],
                fill="none",
                name=short,
                line=dict(width=2.2, color=color),
                marker=dict(size=6, color=color),
                opacity=1.0,
                hovertemplate=(
                    f"<b>{short}</b><br>"
                    "%{theta}: %{r:.2f}<br>"
                    "%{customdata}<extra></extra>"
                ),
            )
        )

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                showticklabels=True,
                tickformat=".2f",
                tickvals=[0, 0.25, 0.5, 0.75, 1.0],
            ),
            angularaxis=dict(direction="clockwise", rotation=90),
        ),
        title=dict(text=title, x=0.5, xanchor="center"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=50, r=120, t=60, b=50),
        showlegend=True,
        **_plotly_legend_config(),
    )
    return fig
def plot_radar_compare(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    metric_keys: Sequence[str] | None = None,
    title: str = "Robustness profile (relative strength; outward is better)",
) -> go.Figure:
    """Deprecated alias kept for older UI bindings; forwards to ``plot_robustness_radar``.

    ``metric_keys`` is accepted for signature compatibility only and is ignored.
    NOTE(review): the default title still says "relative strength" although the
    forwarded radar uses absolute coordinates — confirm whether callers rely on it.
    """
    return plot_robustness_radar(df=df, model_ids=model_ids, title=title)
def plot_compare_models_across_scenarios(
    df: pd.DataFrame,
    model_ids: Sequence[str],
    metric_keys: Sequence[str],
    title: str = "WER by model (one trace per scenario; use legend to toggle)",
) -> go.Figure:
    """Line chart: x = model (sorted by Average WER), y = WER (%), one trace per scenario.

    Adding models extends the x axis instead of adding traces, which keeps the
    legend small (≤ #scenarios) and lets the scenarios of one model be compared
    by reading a vertical slice.

    Args:
        df: Analytics frame with ``model_id`` and the scenario columns.
        model_ids: Models to show on the x axis; empty / None → all rows.
        metric_keys: Scenario columns to plot; keys missing from ``df`` are dropped.
        title: Chart title.

    Returns:
        The line figure, or a placeholder figure when there is nothing to plot.
    """
    metric_keys = [k for k in metric_keys if k in df.columns]
    if df.empty or not metric_keys:
        return _empty_fig("Not enough data.")
    selected = [m for m in (model_ids or []) if m]
    sub = df[df["model_id"].isin(selected)].copy() if selected else df.copy()
    if sub.empty:
        return _empty_fig("No matching models.")

    # Sort models by Average WER (lower is better).
    if "avg_wer" in sub.columns:
        sub = sub.sort_values("avg_wer", ascending=True, na_position="last")
    sub = sub.reset_index(drop=True)

    x_labels = sub["model_id"].str.split("/").str[-1].str[:32].tolist()
    full_ids = sub["model_id"].tolist()
    # Plot against integer slots rather than the labels. The labels are
    # org-stripped and truncated, so different models can share one; as
    # category values they would land on the same x position. The labels are
    # mapped back onto the slots as tick text.
    positions = list(range(len(full_ids)))

    fig = go.Figure()
    palette = qualitative.Plotly
    for i, k in enumerate(metric_keys):
        m = metric_by_key(k)
        scen_label = m.short if m else k
        ys = [_wer_pct(v) if pd.notna(v) else None for v in sub[k].tolist()]
        fig.add_trace(
            go.Scatter(
                x=positions,
                y=ys,
                mode="lines+markers",
                name=scen_label,
                customdata=full_ids,
                line=dict(width=2, color=palette[i % len(palette)]),
                marker=dict(size=8),
                hovertemplate=(
                    f"<b>%{{customdata}}</b><br>{scen_label}: %{{y:.2f}}%<extra></extra>"
                ),
            )
        )

    # Steeper tick angle once there are many models, to keep labels apart.
    tickangle = -45 if len(x_labels) > 8 else -25
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(
            title="Model (sorted by Average WER)",
            tickmode="array",
            tickvals=positions,
            ticktext=x_labels,
            tickangle=tickangle,
            automargin=True,
        ),
        yaxis=dict(title="WER (%) — lower is better", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=120, t=60, b=140),
        **_plotly_legend_config(),
    )
    return fig
def _heatmap_unique_row_labels(model_ids: pd.Series, max_len: int = 42) -> list:
    """Return one distinct, short display label per heatmap row.

    The label is the last ``/``-separated segment of the model id, cut to
    ``max_len`` characters. A categorical axis places equal labels at the same
    position, which would make heatmap rows overwrite each other, so labels
    that collide fall back to the full model id, and any label that is still
    repeated gets a ``(n)`` suffix.
    """
    short = model_ids.str.split("/").str[-1].str[:max_len]
    collides = short.duplicated(keep=False)
    candidates = short.where(~collides, model_ids).tolist()

    seen: dict = {}
    labels: list = []
    for label in candidates:
        count = seen.get(label, 0) + 1
        seen[label] = count
        labels.append(label if count == 1 else f"{label} ({count})")
    return labels


def plot_scenario_heatmap(
    df: pd.DataFrame,
    metric_keys: Sequence[str],
    top_n: int = 30,
    title: str | None = None,
) -> go.Figure:
    """
    Models × scenarios heatmap of WER values (lower = greener).

    This replaces the per-scenario line chart for the common case of *N≫few* models
    and a small fixed set of conditions: every row is one model and every column is
    one condition, so the eye can immediately spot:

    - which models do badly on which conditions (hot rows / cells),
    - which conditions are universally hard (whole columns red),
    - and how the top of the leaderboard compares to the rest.

    `metric_keys` selects which scenario columns to show; columns are displayed in
    the order given by `HEATMAP_SCENARIO_KEYS` (unknown keys go last). Rows follow
    the order produced by `_sort_df_leaderboard_order` and are capped at `top_n`
    (at least one row) so the chart stays legible even with hundreds of
    leaderboard entries.

    Row labels are shortened model ids; when two shortened ids would be equal the
    full id is used instead so that no two rows share a label.

    Returns an empty placeholder figure when there is nothing to plot.
    """
    metric_keys = [k for k in metric_keys if k in df.columns]
    if df.empty or not metric_keys:
        return _empty_fig("Not enough data for the heatmap.")

    d = _sort_df_leaderboard_order(df.copy())
    d = d.head(max(1, int(top_n))).reset_index(drop=True)

    # Omit scenarios with no data in the visible rows (avoids sparse heatmaps).
    keep_keys = [
        k for k in metric_keys if pd.to_numeric(d[k], errors="coerce").notna().any()
    ]
    if not keep_keys:
        return _empty_fig("No WER values for the selected scenarios.")

    order_index = {k: i for i, k in enumerate(HEATMAP_SCENARIO_KEYS)}
    keep_keys = sorted(keep_keys, key=lambda k: order_index.get(k, len(order_index)))
    x_labels = [heatmap_label_for_key(k) for k in keep_keys]
    y_labels = _heatmap_unique_row_labels(d["model_id"])

    # z holds WER in percent (NaN = no value); cell_text is the printed label.
    z: list[list[float]] = []
    cell_text: list[list[str]] = []
    for _, row in d.iterrows():
        z_row: list[float] = []
        t_row: list[str] = []
        for k in keep_keys:
            v = pd.to_numeric(row.get(k), errors="coerce")
            pct = None if pd.isna(v) else _wer_pct(float(v))
            if pct is None:
                z_row.append(float("nan"))
                t_row.append("")
            else:
                z_row.append(float(pct))
                t_row.append(f"{pct:.1f}")
        z.append(z_row)
        cell_text.append(t_row)

    # Colour-scale ceiling: never below 25 %, then rounded up to a step that
    # grows with the largest value so the colour bar ends on a round number.
    flat_vals = [v for row in z for v in row if np.isfinite(v)]
    if flat_vals:
        raw_max = float(max(flat_vals))
        if raw_max <= 25:
            zmax = 25.0
        elif raw_max <= 60:
            zmax = float(np.ceil(raw_max / 5.0) * 5.0)
        elif raw_max <= 100:
            zmax = float(np.ceil(raw_max / 10.0) * 10.0)
        else:
            zmax = float(np.ceil(raw_max / 25.0) * 25.0)
    else:
        zmax = 100.0

    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=x_labels,
            y=y_labels,
            text=cell_text,
            texttemplate="%{text}",
            textfont=dict(size=10, color="#1a1a1a"),
            hovertemplate="<b>%{y}</b><br>%{x}: %{z:.2f}%<extra></extra>",
            colorscale="RdYlGn_r",
            zmin=0.0,
            zmax=zmax,
            xgap=2,
            ygap=2,
            colorbar=dict(title="WER (%)", tickformat=".0f"),
        )
    )

    n_rows = max(1, len(y_labels))
    # Grow the figure with the row count (80 px + 18 px per row), never shorter
    # than the default figure height and capped at 900 px.
    height = int(min(900, max(_FIG_HEIGHT, 80 + 18 * n_rows)))
    fig.update_layout(
        title=dict(
            text=title
            or f"WER heatmap: top {n_rows} models × {len(keep_keys)} scenarios (by Average WER)",
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title=None,
            tickangle=-35,
            side="top",
            automargin=True,
            tickfont=dict(size=10),
        ),
        yaxis=dict(title="Model (sorted by Average WER)", automargin=True, autorange="reversed"),
        template=_TEMPLATE,
        height=height,
        margin=dict(l=80, r=60, t=110, b=60),
    )
    # Scenario tick labels sit on top (side="top"); placing the axis title there too
    # collides with the plot title, so render the "Scenario" label at the bottom.
    fig.add_annotation(
        text="Scenario",
        xref="paper",
        yref="paper",
        x=0.5,
        y=-0.02,
        yanchor="top",
        showarrow=False,
        font=dict(size=13),
    )
    return fig
def plot_clean_vs_reverb_scatter(
    df: pd.DataFrame,
    title: str = "Near Field Speech versus Lab Simulated WER",
) -> go.Figure:
    """
    Scatter: x = near-field speech WER, y = lab-simulated WER (both in percent).

    A dashed y = x reference line is drawn from the origin to the largest plotted
    value; models sit above it when the lab-simulated condition is harder.
    Rows lacking either WER value are dropped. Returns an empty placeholder
    figure when the required columns are missing or no complete pair exists.
    """
    c1, c2 = "wer_anechoic_speech", "wer_lab_simulated"
    if df.empty or c1 not in df.columns or c2 not in df.columns:
        return _empty_fig("Need near-field speech and lab-simulated WER columns in the leaderboard.")
    # Hover labels come from model_id; bail out cleanly instead of raising KeyError.
    if "model_id" not in df.columns:
        return _empty_fig("Missing model_id column for the scatter chart.")

    d = df[["model_id", c1, c2]].copy()
    d[c1] = pd.to_numeric(d[c1], errors="coerce")
    d[c2] = pd.to_numeric(d[c2], errors="coerce")
    d = d.dropna(subset=[c1, c2])
    if d.empty:
        return _empty_fig("No complete near-field speech / lab-simulated WER pairs yet.")

    x_pct = [_wer_pct(v) for v in d[c1]]
    y_pct = [_wer_pct(v) for v in d[c2]]
    # _wer_pct yields None for non-finite inputs; ignore those when sizing the diagonal.
    mx = float(max((v for v in x_pct + y_pct if v is not None), default=0.0))

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=[0, mx],
            y=[0, mx],
            mode="lines",
            name="y = x (equal WER)",
            line=dict(dash="dash", color="rgba(0,0,0,0.35)", width=2),
            hoverinfo="skip",
        )
    )
    short = d["model_id"].str.split("/").str[-1].str[:28]
    fig.add_trace(
        go.Scatter(
            x=x_pct,
            y=y_pct,
            mode="markers",
            # Named explicitly: the legend is shown, and an unnamed trace would
            # appear there under Plotly's "trace 1" placeholder.
            name="Models",
            text=short,
            marker=dict(size=10, opacity=0.85),
            hovertemplate=(
                "<b>%{text}</b><br>Near field speech WER: %{x:.2f}%<br>"
                "Lab simulated WER: %{y:.2f}%<extra></extra>"
            ),
        )
    )
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(title="WER (%), Near Field Speech (lower is better)", rangemode="tozero"),
        yaxis=dict(title="WER (%), Lab Simulated (lower is better)", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=40, t=60, b=60),
        showlegend=True,
    )
    return fig
def plot_pareto_frontier(
    df: pd.DataFrame,
    title: str | None = None,
) -> go.Figure:
    """
    Pareto front: X = Average WER (%, live scenarios), Y = RTFx (log scale).

    Frontier models are highlighted with star markers + name labels and connected
    by a dashed line; non-frontier models render as light dots.

    Rows whose Average WER or RTFx is missing, non-numeric or non-finite, or whose
    RTFx is not strictly positive (it cannot be placed on a log axis), are left
    out. The X axis spans 0–100 % and is widened only when a plotted model
    exceeds 100 % WER. Returns an empty placeholder figure when the required
    columns are absent or no usable row remains.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")
    need = {"model_id", "avg_wer", "eval_rtf"}
    if not need.issubset(df.columns):
        return _empty_fig("Missing Average WER or RTFx columns for the Pareto chart.")

    d = df[["model_id", "avg_wer", "eval_rtf"]].copy()
    d["avg_wer"] = pd.to_numeric(d["avg_wer"], errors="coerce")
    d["eval_rtf"] = pd.to_numeric(d["eval_rtf"], errors="coerce")
    d = d.dropna(subset=["avg_wer", "eval_rtf"])
    # dropna keeps +/-inf; filter those here so the axis ranges stay finite and
    # the emptiness check below covers every way a row can be unusable.
    usable = np.isfinite(d["avg_wer"]) & np.isfinite(d["eval_rtf"]) & (d["eval_rtf"] > 0)
    d = d[usable].reset_index(drop=True)
    if d.empty:
        return _empty_fig(
            "No models with Average WER and RTFx. Re-run evaluations to populate timing."
        )
    d = d.assign(wer_pct=d["avg_wer"].map(_wer_pct))

    # Layer-1 frontier on (avg_wer, eval_rtf): lower WER + higher RTF is better.
    fx_raw, fy = _pareto_efficient_frontier_wer_rtf(
        d["avg_wer"].tolist(),
        d["eval_rtf"].tolist(),
    )
    # Match rows back to frontier points on rounded coordinates to sidestep
    # float round-trip noise.
    frontier_pairs = {(round(w, 8), round(r, 8)) for w, r in zip(fx_raw, fy)}
    on_frontier = [
        (round(float(w), 8), round(float(r), 8)) in frontier_pairs
        for w, r in zip(d["avg_wer"], d["eval_rtf"])
    ]
    d = d.assign(_frontier=on_frontier)
    frontier_df = d[d["_frontier"]].sort_values("wer_pct").reset_index(drop=True)
    other_df = d[~d["_frontier"]]

    accent = "#1f77ff"  # cool blue (matches screenshot)
    accent_light = "rgba(31, 119, 255, 0.30)"
    hover = (
        "<b>%{text}</b><br>"
        "Average WER: %{x:.2f}%<br>"
        "RTFx: %{y:.2f}×<extra></extra>"
    )

    fig = go.Figure()
    if not other_df.empty:
        fig.add_trace(
            go.Scatter(
                x=other_df["wer_pct"],
                y=other_df["eval_rtf"],
                mode="markers",
                name="Other models",
                marker=dict(size=9, color=accent_light, line=dict(width=0, color="white")),
                text=other_df["model_id"],
                hovertemplate=hover,
            )
        )
    if not frontier_df.empty:
        fig.add_trace(
            go.Scatter(
                x=frontier_df["wer_pct"],
                y=frontier_df["eval_rtf"],
                mode="lines+markers+text",
                name="Pareto frontier",
                line=dict(color=accent, width=2, dash="dash"),
                marker=dict(size=14, color=accent, symbol="star", line=dict(width=0, color="white")),
                text=frontier_df["model_id"],
                textposition="top center",
                textfont=dict(size=11, color=accent),
                cliponaxis=False,
                hovertemplate=hover,
            )
        )

    ttl = title or "Pareto Front: Average WER vs RTFx"

    # Log-axis ranges are given as log10 exponents; pad a little on both sides.
    rtf_min = float(d["eval_rtf"].min())
    rtf_max = float(d["eval_rtf"].max())
    log_y_lo = float(np.log10(max(rtf_min * 0.85, 1e-4)))
    log_y_hi = float(np.log10(rtf_max * 1.2))

    # Keep the familiar 0-100 % window, but do not clip models above 100 % WER.
    wer_max = float(d["wer_pct"].max())
    x_hi = 100.0 if wer_max <= 100.0 else float(np.ceil(wer_max / 10.0) * 10.0)

    fig.update_layout(
        title=dict(text=ttl, x=0.0, xanchor="left", font=dict(size=18)),
        xaxis=dict(
            title="Average WER (lower is better)",
            ticksuffix="",
            zeroline=False,
            range=[0, x_hi],
        ),
        yaxis=dict(
            title="RTFx (higher is better)",
            type="log",
            zeroline=False,
            range=[log_y_lo, log_y_hi],
            tickformat=".2g",
        ),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        autosize=True,
        margin=dict(l=70, r=40, t=70, b=60),
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.98,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255,255,255,0.7)",
            borderwidth=0,
        ),
    )
    return fig
def plot_scenario_bar_summary(df: pd.DataFrame, top_n: int = 8) -> go.Figure:
    """Grouped WER bars per scenario, one trace per model (the legend toggles models).

    Keeps the `top_n` rows with the lowest `avg_wer`; when that column is absent
    the row-wise mean of the available scenario columns is used for ranking.
    Returns an empty placeholder figure when there is no data or no scenario
    column from `HEATMAP_SCENARIO_KEYS` is present.
    """
    if df.empty:
        return _empty_fig("No leaderboard data yet.")

    scenario_cols = [key for key in HEATMAP_SCENARIO_KEYS if key in df.columns]
    if not scenario_cols:
        return _empty_fig("No core WER columns.")

    ranked = df.copy()
    if "avg_wer" not in ranked.columns:
        # No precomputed average: rank on the mean of whatever scenarios exist.
        ranked["_avg"] = ranked[scenario_cols].mean(axis=1, skipna=True)
        rank_col = "_avg"
    else:
        rank_col = "avg_wer"
    ranked = ranked.sort_values(rank_col, ascending=True, na_position="last").head(int(top_n))

    # Prefer the metric's short display name; fall back to the raw key.
    tick_labels = []
    for key in scenario_cols:
        spec = metric_by_key(key)
        tick_labels.append(spec.short if spec else key)

    colors = qualitative.Plotly
    fig = go.Figure()
    for pos, (_, model_row) in enumerate(ranked.iterrows()):
        bar_values = [
            None if pd.isna(model_row[key]) else _wer_pct(model_row[key])
            for key in scenario_cols
        ]
        short = model_row["model_id"].split("/")[-1][:22]
        bar = go.Bar(
            name=short,
            x=tick_labels,
            y=bar_values,
            marker_color=colors[pos % len(colors)],
            hovertemplate=f"<b>{short}</b><br>%{{x}}: %{{y:.2f}}%<extra></extra>",
        )
        fig.add_trace(bar)

    chart_title = dict(
        text=f"WER by scenario: top {len(ranked)} models by Average WER (legend toggles models)",
        x=0.5,
        xanchor="center",
    )
    fig.update_layout(
        barmode="group",
        title=chart_title,
        xaxis=dict(title="Scenario", tickangle=-20),
        yaxis=dict(title="WER (%)", rangemode="tozero"),
        template=_TEMPLATE,
        height=_FIG_HEIGHT,
        margin=dict(l=60, r=120, t=60, b=100),
        **_plotly_legend_config(),
    )
    return fig
def plot_lines_by_rank(
    df: pd.DataFrame,
    metric_keys: Sequence[str],
    title: str | None = None,
) -> go.Figure:
    """Deprecated: use plot_clean_vs_reverb_scatter or plot_latency_vs_wer.

    Kept as a thin shim that forwards to plot_clean_vs_reverb_scatter;
    `metric_keys` is accepted for signature compatibility and not used.
    """
    chart_title = title if title else "Clean vs reverberant WER"
    return plot_clean_vs_reverb_scatter(df, title=chart_title)