ffasr / utils_display.py
whojavumusic's picture
small blemish fixes
b60e9d6
Raw
History Blame Contribute Delete
6.63 kB
from __future__ import annotations

import math
from dataclasses import dataclass
# These classes hold the user-facing column names, so that renaming a column
# only requires a change here rather than all around the code.
@dataclass
class ColumnContent:
    """Display metadata for a single leaderboard column."""

    name: str  # header text shown to users
    type: str  # column type string, e.g. "markdown" or "number"
def fields(raw_class):
    """Return the values of the non-dunder attributes in ``raw_class.__dict__``.

    Any attribute whose name starts with ``"__"`` or ends with ``"__"`` is
    skipped. The remaining values are returned in ``__dict__`` order.
    """

    def _is_dunder_like(attr_name):
        return attr_name.startswith("__") or attr_name.endswith("__")

    return [
        value
        for attr_name, value in raw_class.__dict__.items()
        if not _is_dunder_like(attr_name)
    ]
# Metric keys (as named in the results CSV) that are averaged to produce the
# headline "Avg WER (%)" column: the dry/anechoic condition plus the three
# realistic-SNR conditions.
AVG_WER_CORE_KEYS: tuple[str, ...] = (
    "wer_anechoic_speech",  # dry / anechoic speech
    # realistic conditions, from highest to lowest SNR
    "wer_realistic_high_snr",
    "wer_realistic_mid_snr",
    "wer_realistic_low_snr",
)
@dataclass(frozen=True)
class AutoEvalColumn:
    """User-facing leaderboard column definitions.

    Each attribute is a ``ColumnContent`` pairing a display header with a
    column type. The attributes carry no type annotations, so ``@dataclass``
    does not turn them into fields; callers read them straight off the class
    (e.g. ``AutoEvalColumn.model.name``).

    NOTE(review): declaration order is kept deliberately. It is presumably the
    column order of the rendered table (``fields()`` walks ``__dict__`` in
    order), so confirm with the callers before reordering.
    """

    # Identity and headline score
    model = ColumnContent(name="Model", type="markdown")
    avg_wer_core = ColumnContent(name="Avg WER (%)", type="number")

    # Per-condition WER columns
    wer_anechoic = ColumnContent(name="Near Field Speech", type="number")
    wer_lab_measured = ColumnContent(name="Lab Measured", type="number")
    wer_lab_simulated = ColumnContent(name="Lab Simulated", type="number")
    wer_realistic_high_snr = ColumnContent(name="High SNR", type="number")
    wer_realistic_mid_snr = ColumnContent(name="Mid SNR", type="number")
    wer_realistic_low_snr = ColumnContent(name="Low SNR", type="number")
    wer_moving_low = ColumnContent(name="Moving Low SNR *", type="number")
    wer_moving_mid = ColumnContent(name="Moving Mid SNR *", type="number")
    wer_moving_high = ColumnContent(name="Moving High SNR *", type="number")

    # Speed and size
    eval_rtf = ColumnContent(name="RTFx", type="number")
    # NOTE(review): the attribute name suggests millions ("_m") but the header
    # says billions "(B)"; confirm which unit the underlying data uses.
    params_m = ColumnContent(name="Params (B)", type="number")
# Leaderboard display header -> canonical CSV / analytics metric key, with one
# entry per per-condition WER benchmark column, in display order.
SCENARIO_DISPLAY_TO_KEY: dict[str, str] = {
    AutoEvalColumn.wer_anechoic.name: "wer_anechoic_speech",
    AutoEvalColumn.wer_lab_measured.name: "wer_lab_measured",
    AutoEvalColumn.wer_lab_simulated.name: "wer_lab_simulated",
    AutoEvalColumn.wer_realistic_high_snr.name: "wer_realistic_high_snr",
    AutoEvalColumn.wer_realistic_mid_snr.name: "wer_realistic_mid_snr",
    AutoEvalColumn.wer_realistic_low_snr.name: "wer_realistic_low_snr",
    AutoEvalColumn.wer_moving_low.name: "wer_moving_low",
    AutoEvalColumn.wer_moving_mid.name: "wer_moving_mid",
    AutoEvalColumn.wer_moving_high.name: "wer_moving_high",
}

# Display names of the per-condition WER benchmark columns, in the same order
# as the mapping above. Used to:
#   * compute Average WER from the visible benchmark columns in the leaderboard filter,
#   * sanity-check column visibility logic.
# Derived from the mapping's keys (dicts preserve insertion order) so that the
# list of columns and the header->key mapping cannot drift apart.
SCENARIO_DISPLAY_COLS: tuple[str, ...] = tuple(SCENARIO_DISPLAY_TO_KEY)
# Fixed pixel width per column header, so the markdown Model column cannot grow
# and push the WER columns out of view. Headers missing from this table fall
# back to 80px in column_widths_for.
_COLUMN_WIDTHS_PX: dict[str, int] = {
    column.name: width_px
    for column, width_px in (
        (AutoEvalColumn.model, 200),
        (AutoEvalColumn.avg_wer_core, 88),
        (AutoEvalColumn.wer_anechoic, 96),
        (AutoEvalColumn.wer_lab_measured, 96),
        (AutoEvalColumn.wer_lab_simulated, 96),
        (AutoEvalColumn.wer_realistic_high_snr, 72),
        (AutoEvalColumn.wer_realistic_mid_snr, 68),
        (AutoEvalColumn.wer_realistic_low_snr, 68),
        (AutoEvalColumn.wer_moving_low, 112),
        (AutoEvalColumn.wer_moving_mid, 112),
        (AutoEvalColumn.wer_moving_high, 116),
        (AutoEvalColumn.eval_rtf, 58),
        (AutoEvalColumn.params_m, 62),
    )
}
def column_widths_for(columns: list[str]) -> list[str]:
    """Build the Gradio Dataframe ``column_widths`` list for the visible columns.

    Each entry is a CSS pixel string such as ``"96px"``, taken from
    ``_COLUMN_WIDTHS_PX``. Headers without a configured width get 80px.
    """
    default_px = 80
    return [
        "{}px".format(_COLUMN_WIDTHS_PX.get(header, default_px))
        for header in columns
    ]
# Custom URL mappings for models not on HuggingFace or with special pages
_CUSTOM_LINKS = {
"elevenlabs": "https://elevenlabs.io/speech-to-text",
"assemblyai": "https://www.assemblyai.com/docs/getting-started/universal-3-pro",
"aquavoice": "https://aquavoice.com/blog/introducing-avalon",
"zoom": "https://developers.zoom.us/docs/ai-services/scribe/",
"revai": "https://docs.rev.ai/api/asynchronous/get-started/",
"speechmatics": "https://docs.speechmatics.com/speech-to-text/batch/quickstart",
"google/chirp": "https://cloud.google.com/blog/products/ai-machine-learning/bringing-power-large-models-google-clouds-speech-api",
"google/chirp_2": "https://docs.cloud.google.com/speech-to-text/docs/models/chirp-2",
"google/chirp_3": "https://docs.cloud.google.com/speech-to-text/docs/models/chirp-3",
"trt-llm": "https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper",
"faster-whisper": "https://github.com/guillaumekln/faster-whisper",
"Whisper.cpp": "https://github.com/ggerganov/whisper.cpp",
"WhisperKit": "https://github.com/argmaxinc/WhisperKit",
"WhisperMLX": "https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc",
}
def format_wer_percent(v, *, ndigits: int = 2) -> str | float:
"""Fractional WER (0–1) → display percent, or ``'NA'``."""
if v is None:
return "NA"
if isinstance(v, str) and not str(v).strip():
return "NA"
try:
f = float(v)
except (TypeError, ValueError):
return "NA"
if f != f: # NaN
return "NA"
return round(f * 100.0, ndigits)
def make_clickable_model(model_name):
    """Render *model_name* as an HTML anchor that opens in a new tab.

    The link target is resolved in this order:
      1. the full name as a key of ``_CUSTOM_LINKS`` (e.g. ``"google/chirp_2"``);
      2. the segment before the first ``/`` as a key
         (e.g. ``"elevenlabs/scribe_v2"`` -> ``"elevenlabs"``);
      3. otherwise the Hugging Face page ``https://huggingface.co/<model_name>``.
    The anchor text is *model_name* itself.
    """
    lookup_keys = (model_name, model_name.split("/")[0])
    matched = next((key for key in lookup_keys if key in _CUSTOM_LINKS), None)
    if matched is not None:
        href = _CUSTOM_LINKS[matched]
    else:
        href = f"https://huggingface.co/{model_name}"
    css = (
        "color: var(--link-text-color); text-decoration: underline;"
        "text-decoration-style: dotted;"
    )
    return f'<a target="_blank" href="{href}" style="{css}">{model_name}</a>'
def styled_error(error):
    """Wrap *error* in a centered, red, 18px HTML paragraph."""
    css = "color: red; font-size: 18px; text-align: center;"
    return f"<p style='{css}'>{error}</p>"
def styled_warning(warn):
    """Wrap *warn* in a centered, orange, 18px HTML paragraph."""
    css = "color: orange; font-size: 18px; text-align: center;"
    return f"<p style='{css}'>{warn}</p>"
def styled_message(message):
    """Wrap *message* in a centered, green, 18px HTML paragraph."""
    css = "color: green; font-size: 18px; text-align: center;"
    return f"<p style='{css}'>{message}</p>"
def model_id_from_leaderboard_cell(cell) -> str:
    """Recover the Hugging Face model id from a cell of the Model column.

    - Non-string cells yield ``""``.
    - A cell that ends in an HTML anchor (``...>model/id</a>``) yields the
      stripped anchor text.
    - Any other string is returned with surrounding whitespace stripped.
    """
    if not isinstance(cell, str):
        return ""
    text = cell.strip()
    if "</a>" not in text:
        return text
    import re

    # Anchor text is everything between the last ">" and a trailing "</a>".
    match = re.search(r">([^<]+)</a>\s*$", text)
    return match.group(1).strip() if match else text