""" Streamlit app to inspect MMS evaluation results. Run from the finetuning/mms directory: streamlit run inspect_results.py Or point at a different results directory: streamlit run inspect_results.py -- --results_dir /path/to/results """ import difflib import html as html_lib import os import re from pathlib import Path import pandas as pd import streamlit as st from huggingface_hub import hf_hub_download def check_password(): if "authenticated" not in st.session_state: st.session_state.authenticated = False if not st.session_state.authenticated: st.markdown("## πŸ”’ Login required") password = st.text_input("Enter password", type="password") if st.button("Login"): try: correct = os.environ.get("APP_PASSWORD") except Exception as e: st.error(f"Error: {e}") st.stop() if password == correct: st.session_state.authenticated = True st.rerun() else: st.error("Incorrect password") st.stop() DATASETS = ["clb_belize", "hf_fongbe", "clb", "fongbe", "alffa"] SPLITS = ["validation", "test"] DECODES = ["greedy", "lm", "beam"] _FNAME_RE = re.compile( rf"^eval_(.+?)_({'|'.join(DATASETS)})_({'|'.join(SPLITS)})_({'|'.join(DECODES)})\.csv$" ) CSS = """ """ def wer_badge(w: float, label: str = "WER") -> str: cls = "badge-good" if w < 0.3 else ("badge-ok" if w < 0.7 else "badge-bad") return f'{label} {w:.2f}' def cer_badge(c: float, label: str = "CER") -> str: cls = "badge-good" if c < 0.15 else ("badge-ok" if c < 0.4 else "badge-bad") return f'{label} {c:.2f}' _COLORS = { "correct": ("#1e7e44", "#edfaf2"), "close": ("#916a00", "#fff8e6"), "wrong": ("#b91c1c", "#fff1f1"), "insert": ("#b91c1c", "#ffeaea"), } def _get_slots(ref: str, pred: str) -> list[tuple]: """Return list of (ref_word, pred_word, kind) alignment slots.""" ref_words = str(ref).split() pred_words = str(pred).split() slots = [] for op, i1, i2, j1, j2 in difflib.SequenceMatcher(None, ref_words, pred_words, autojunk=False).get_opcodes(): if op == "equal": for rw, pw in zip(ref_words[i1:i2], pred_words[j1:j2]): slots.append((rw, pw, "correct")) elif op == "replace": r_chunk, p_chunk = ref_words[i1:i2], pred_words[j1:j2] for rw, pw in zip(r_chunk, p_chunk): sim = difflib.SequenceMatcher(None, rw, pw).ratio() slots.append((rw, pw, "close" if sim >= 0.6 else "wrong")) for w in r_chunk[len(p_chunk):]: slots.append((w, "", "wrong")) for w in p_chunk[len(r_chunk):]: slots.append(("", w, "insert")) elif op == "delete": for w in ref_words[i1:i2]: slots.append((w, "", "wrong")) elif op == "insert": for w in pred_words[j1:j2]: slots.append(("", w, "insert")) return slots def _word_html(w: str, kind: str, is_ref: bool, highlight: bool) -> str: if not w: return 'Β·' if (highlight and not is_ref) else "" esc = html_lib.escape(w) if not highlight or is_ref: return esc color, bg = _COLORS[kind] extra = ";text-decoration:underline dotted" if kind == "insert" else "" return f'{esc}' def render_aligned_cols(ref: str, pred: str, highlight: bool) -> str: """Two-column layout where each segment row covers one alignment run. Segments break at equal↔error transitions (and every MAX_EQUAL words within a long equal run) so both columns wrap at the same boundaries. """ MAX_EQUAL = 8 slots = _get_slots(ref, pred) # Group slots into segments segs: list[list[tuple]] = [] cur: list[tuple] = [] cur_eq: bool | None = None for slot in slots: is_eq = slot[2] == "correct" if cur_eq is not None and (is_eq != cur_eq or (is_eq and len(cur) >= MAX_EQUAL)): segs.append(cur) cur = [] cur.append(slot) cur_eq = is_eq if cur: segs.append(cur) rows = [] for seg in segs: ref_parts = [_word_html(rw, k, True, highlight) for rw, pw, k in seg] pred_parts = [_word_html(pw, k, False, highlight) for rw, pw, k in seg] ref_html = " ".join(p for p in ref_parts if p) pred_html = " ".join(p for p in pred_parts if p) if not ref_html and not pred_html: continue rows.append( f'
' f'
{ref_html}
' f'
{pred_html}
' f'
' ) return "\n".join(rows) def render_flat_cols(ref: str, pred: str, highlight: bool) -> str: """Plain side-by-side columns β€” optionally with word colours on prediction only.""" ref_html = html_lib.escape(str(ref)) if highlight: slots = _get_slots(ref, pred) pred_html = " ".join(_word_html(pw, k, False, True) for rw, pw, k in slots if pw) else: pred_html = html_lib.escape(str(pred)) return ( '
' f'
{ref_html}
' f'
{pred_html}
' '
' ) def zoom_pred_html(ref: str, pred: str, highlight: bool) -> str: """Flat coloured prediction text for the zoom dialog.""" if not highlight: return f'
{html_lib.escape(str(pred))}
' parts = [] for _, pw, kind in _get_slots(ref, pred): parts.append(_word_html(pw, kind, False, True)) return f'
{" ".join(p for p in parts if p)}
' def parse_result_files(results_dir: str) -> list[dict]: entries = [] if not os.path.isdir(results_dir): return entries for fname in sorted(os.listdir(results_dir)): m = _FNAME_RE.match(fname) if m: entries.append({ "model": m.group(1), "dataset": m.group(2), "split": m.group(3), "decode": m.group(4), "path": os.path.join(results_dir, fname), }) return entries def load_all_dfs(entries: list[dict]) -> dict[tuple, pd.DataFrame]: """Load every result CSV into a dict keyed by (model, dataset, split, decode).""" dfs = {} for e in entries: df = pd.read_csv(e["path"]) df["path"] = df["path"].apply(resolve_audio_path) dfs[(e["model"], e["dataset"], e["split"], e["decode"])] = df return dfs _APP_DIR = os.path.dirname(os.path.abspath(__file__)) def get_results_dir() -> str: return os.path.join(_APP_DIR, "results") def resolve_audio_path(path: str) -> str: """Return an absolute path, resolving relative paths against the app directory.""" if os.path.isabs(path): return path return os.path.join(_APP_DIR, path) # ── HF dataset audio fallback ────────────────────────────────────────────── # Maps the top-level folder name (as it appears in CSV paths) to the HF repo. # Strip the folder prefix to get the path inside the repo. _HF_AUDIO_DATASETS: dict[str, str] = { "audio_chunks": "clb-benin/clb_data", "fongbe_speech_audio_files": "clb-benin/fongbe-data", } def _hf_token() -> str | None: """Return HF token from Streamlit secrets or HF_TOKEN env var, or None.""" try: return st.secrets.get("HF_TOKEN") # type: ignore[attr-defined] except Exception: pass return os.environ.get("HF_TOKEN") @st.cache_data(show_spinner=False) def get_audio_bytes(audio_path: str) -> bytes | None: """Load audio bytes from a local file; fall back to HF dataset if not found.""" if os.path.isfile(audio_path): with open(audio_path, "rb") as f: return f.read() # Derive path relative to the app dir so we can match the HF mapping. try: rel = os.path.relpath(audio_path, _APP_DIR) except ValueError: return None parts = Path(rel).parts # e.g. ("audio_chunks", "subfolder", "chunk_0000.wav") if parts and parts[0] in _HF_AUDIO_DATASETS: repo_id = _HF_AUDIO_DATASETS[parts[0]] hf_file = "/".join(parts[1:]) # strip the local top-level folder try: local = hf_hub_download( repo_id=repo_id, filename=hf_file, repo_type="dataset", token=_hf_token(), ) with open(local, "rb") as f: return f.read() except Exception: pass return None # ── Dialog for zoomed example ────────────────────────────────────────────── @st.dialog("Example detail", width="large") def show_example_detail(audio_path: str, all_dfs: dict, highlight: bool): audio_bytes = get_audio_bytes(audio_path) if audio_bytes: st.audio(audio_bytes, format="audio/wav") st.caption(os.path.join(os.path.basename(os.path.dirname(audio_path)), os.path.basename(audio_path))) st.markdown("
", unsafe_allow_html=True) ref_text = None ref_shown = False for (model, dataset, split, decode), df in all_dfs.items(): match = df[df["path"] == audio_path] if match.empty: continue row = match.iloc[0] if not ref_shown: ref_text = str(row["ref"]) st.markdown( f'
' f'
Reference
' f'
{html_lib.escape(ref_text)}
' f'
', unsafe_allow_html=True, ) ref_shown = True pred_block = zoom_pred_html(ref_text or "", str(row["pred"]), highlight) st.markdown( f'
' f'
{model}  Β·  {dataset}  Β·  {split}  Β·  {decode}
' f'{pred_block}' f'
{wer_badge(row["wer"])}{cer_badge(row["cer"])}
' f'
', unsafe_allow_html=True, ) @st.dialog("Normalized text", width="large") def show_normalized_detail(ref_norm: str, pred_norm: str, wer_n: float, cer_n: float, hl: bool): st.markdown( f'
' f'{wer_badge(wer_n, "WER~")}{cer_badge(cer_n, "CER~")}' f'
', unsafe_allow_html=True, ) st.markdown("
", unsafe_allow_html=True) st.markdown( f'
' f'
Reference (normalized)
' f'
{html_lib.escape(ref_norm)}
' f'
', unsafe_allow_html=True, ) pred_block = zoom_pred_html(ref_norm, pred_norm, hl) st.markdown( f'
' f'
Prediction (normalized)
' f'{pred_block}' f'
', unsafe_allow_html=True, ) # ── App ──────────────────────────────────────────────────────────────────── st.set_page_config(page_title="MMS Evaluation Inspector", layout="wide") st.markdown(CSS, unsafe_allow_html=True) check_password() results_dir = get_results_dir() entries = parse_result_files(results_dir) if not entries: st.error(f"No evaluation CSVs found in `{results_dir}`. Run `evaluate_mms_model.py` first.") st.stop() all_dfs = load_all_dfs(entries) models = sorted(set(e["model"] for e in entries)) datasets = sorted(set(e["dataset"] for e in entries)) splits = sorted(set(e["split"] for e in entries)) decodes = sorted(set(e["decode"] for e in entries)) has_norm = any("wer_normalized" in df.columns for df in all_dfs.values()) # ── Sidebar ──────────────────────────────────────────────────────────────── with st.sidebar: st.markdown("### Run") sel_model = st.selectbox("Model", models) sel_dataset = st.selectbox("Dataset", datasets) sel_split = st.selectbox("Split", splits) sel_decode = st.selectbox("Decoding", decodes) st.markdown("---") st.markdown("### Display") sort_opts = ["original order", "WER ↑", "WER ↓", "CER ↑", "CER ↓"] if has_norm: sort_opts += ["WER~ ↑", "WER~ ↓", "CER~ ↑", "CER~ ↓"] sort_by = st.selectbox("Sort by", sort_opts) max_examples = st.slider("Examples", min_value=5, max_value=500, value=50, step=5) # ── Find file ────────────────────────────────────────────────────────────── key = (sel_model, sel_dataset, sel_split, sel_decode) df = all_dfs.get(key) if df is None: st.warning("No results file found for this combination.") st.stop() # ── Header ───────────────────────────────────────────────────────────────── st.markdown("## MMS Evaluation") st.markdown( f"" f"{sel_model}  Β·  {sel_dataset}  Β·  {sel_split}  Β·  {sel_decode}" f"", unsafe_allow_html=True, ) st.markdown("
", unsafe_allow_html=True) if has_norm: col1, col2, col3, col4, col5 = st.columns(5) col1.metric("Examples", len(df)) col2.metric("Avg WER", f"{df['wer'].mean():.3f}") col3.metric("Avg CER", f"{df['cer'].mean():.3f}") col4.metric("Avg WER~", f"{df['wer_normalized'].mean():.3f}") col5.metric("Avg CER~", f"{df['cer_normalized'].mean():.3f}") else: col1, col2, col3 = st.columns(3) col1.metric("Examples", len(df)) col2.metric("Avg WER", f"{df['wer'].mean():.3f}") col3.metric("Avg CER", f"{df['cer'].mean():.3f}") st.markdown("
", unsafe_allow_html=True) # ── Sort ─────────────────────────────────────────────────────────────────── _sort_map = { "WER ↑": ("wer", True), "WER ↓": ("wer", False), "CER ↑": ("cer", True), "CER ↓": ("cer", False), "WER~ ↑": ("wer_normalized", True), "WER~ ↓": ("wer_normalized", False), "CER~ ↑": ("cer_normalized", True), "CER~ ↓": ("cer_normalized", False), } if sort_by in _sort_map: col, asc = _sort_map[sort_by] df = df.sort_values(col, ascending=asc).reset_index(drop=True) # ── Beam results for inline comparison ──────────────────────────────────── beam_df = ( all_dfs.get((sel_model, sel_dataset, sel_split, "beam")) if sel_decode != "beam" else None ) # ── Examples ─────────────────────────────────────────────────────────────── for i, row in df.head(max_examples).iterrows(): audio_path = row["path"] with st.container(border=True): # Header: filename + badges norm_badges = ( f' {wer_badge(row["wer_normalized"], "WER~")}{cer_badge(row["cer_normalized"], "CER~")}' if has_norm else "" ) st.markdown( f'
' f'{os.path.join(os.path.basename(os.path.dirname(audio_path)), os.path.basename(audio_path))}' f'
{wer_badge(row["wer"])}{cer_badge(row["cer"])}{norm_badges}
' f'
', unsafe_allow_html=True, ) left, right = st.columns([1, 2], gap="large") with left: audio_bytes = get_audio_bytes(audio_path) if audio_bytes: st.audio(audio_bytes, format="audio/wav") else: st.caption("audio not found") tc1, tc2 = st.columns(2) hl = tc1.toggle("Highlight", key=f"hl_{i}", value=False) al = tc2.toggle("Align", key=f"al_{i}", value=False) if st.button("Compare all runs", key=f"zoom_{i}", use_container_width=True): show_example_detail(audio_path, all_dfs, hl) if has_norm and st.button("Normalized view", key=f"norm_{i}", use_container_width=True): show_normalized_detail( str(row["ref_norm"]), str(row["pred_norm"]), float(row["wer_normalized"]), float(row["cer_normalized"]), hl, ) with right: labels = ( '
' '
Reference
' '
Prediction
' '
' ) body = render_aligned_cols(str(row['ref']), str(row['pred']), hl) if al \ else render_flat_cols(str(row['ref']), str(row['pred']), hl) st.markdown(f'{labels}{body}', unsafe_allow_html=True) if beam_df is not None: beam_match = beam_df[beam_df["path"] == audio_path] if not beam_match.empty: beam_row = beam_match.iloc[0] beam_pred_block = zoom_pred_html(str(row["ref"]), str(beam_row["pred"]), hl) beam_scores = ( f'
' f'{wer_badge(beam_row["wer"])}{cer_badge(beam_row["cer"])}' f'
' ) st.markdown( f'
' f'
' f'
Beam
' f'{beam_pred_block}' f'{beam_scores}' f'
', unsafe_allow_html=True, )