eval-analysis / src /streamlit_app.py
simonevanbruggen's picture
Update src/streamlit_app.py
89c259f verified
Raw
History Blame Contribute Delete
24.4 kB
"""
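Streamlit app to inspect MMS evaluation results.

Run it with Streamlit, pointing at this file:

    streamlit run src/streamlit_app.py

Evaluation CSVs are read from the ``results`` directory next to this file
(see ``get_results_dir``); a ``--results_dir`` command-line option is not
implemented. Access requires the password stored in the ``APP_PASSWORD``
environment variable.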
"""
import difflib
import hmac
import html as html_lib
import logging
import os
import re
from pathlib import Path

import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
def check_password():
    """Gate the app behind the password stored in the ``APP_PASSWORD`` env var.

    Renders a login form and ends the current script run with ``st.stop()``
    for as long as the session is not authenticated, so statements placed
    after this call do not run for an anonymous visitor. The flag is kept in
    ``st.session_state.authenticated``.

    Login is refused (fail closed) when ``APP_PASSWORD`` is unset or empty,
    so an empty submitted password can never match.
    """
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False
    if st.session_state.authenticated:
        return

    # \N{LOCK} keeps the source ASCII-only so the emoji cannot be mis-encoded.
    st.markdown("## \N{LOCK} Login required")
    password = st.text_input("Enter password", type="password")
    if st.button("Login"):
        correct = os.environ.get("APP_PASSWORD")
        if not correct:
            st.error("APP_PASSWORD is not configured; login is disabled.")
            st.stop()
        # Compare as bytes: compare_digest rejects non-ASCII str arguments,
        # and it avoids leaking the match length through timing.
        if hmac.compare_digest(password.encode("utf-8"), correct.encode("utf-8")):
            st.session_state.authenticated = True
            st.rerun()
        else:
            st.error("Incorrect password")
    # Still unauthenticated: do not let the rest of the script render.
    st.stop()
# Closed vocabularies for the dataset, split and decoding parts of a result
# file name.
DATASETS = ["clb_belize", "hf_fongbe", "clb", "fongbe", "alffa"]
SPLITS = ["validation", "test"]
DECODES = ["greedy", "lm", "beam"]

# Result files are named eval_<model>_<dataset>_<split>_<decode>.csv.
# The model group is lazy, so a model name may itself contain underscores;
# the three trailing groups only accept the vocabularies above.
_FNAME_RE = re.compile(
    "^eval_(.+?)"
    + "".join(f"_({'|'.join(options)})" for options in (DATASETS, SPLITS, DECODES))
    + r"\.csv$"
)
# Global stylesheet, injected once with st.markdown(CSS, unsafe_allow_html=True).
# It restyles Streamlit's own widgets through their data-testid / data-baseweb
# attributes and defines the app's own classes (.ex-*, .badge*, .zoom-*) that
# the HTML-building helpers in this file emit.
CSS = """
<style>
/* ── Backgrounds ── */
[data-testid="stAppViewContainer"],
[data-testid="stAppViewContainer"] > .main,
[data-testid="stHeader"] { background-color: #f8f8f6; }
/* ── Typography β€” avoid overriding icon fonts by targeting text nodes only ── */
body, p, h1, h2, h3, h4, h5, h6,
[data-testid="stMarkdownContainer"],
[data-testid="stText"],
[data-testid="stMetricLabel"],
[data-testid="stMetricValue"],
[data-testid="stSidebar"] label,
.ex-card, .zoom-row, .badge, .ex-filename {
font-family: system-ui, -apple-system, "Segoe UI", sans-serif;
color: #1a1a1a;
}
/* ── Sidebar ── */
[data-testid="stSidebar"] {
background-color: #f0efed;
border-right: 1px solid #e5e3e0;
}
[data-testid="stSidebar"] label {
font-size: 0.68rem !important;
font-weight: 600 !important;
letter-spacing: 0.07em;
text-transform: uppercase;
color: #888 !important;
}
/* ── All select boxes (sidebar + main) ── */
[data-baseweb="select"] > div {
background-color: #fff !important;
border: 1px solid #ddd !important;
border-radius: 8px !important;
color: #1a1a1a !important;
}
[data-baseweb="select"] span { color: #1a1a1a !important; }
/* Dropdown menu popup */
[data-baseweb="popover"],
[data-baseweb="menu"] {
background-color: #fff !important;
border: 1px solid #e5e5e5 !important;
border-radius: 8px !important;
box-shadow: 0 4px 16px rgba(0,0,0,0.08) !important;
}
[data-baseweb="menu"] li,
[data-baseweb="menu"] [role="option"] {
background-color: #fff !important;
color: #1a1a1a !important;
}
[data-baseweb="menu"] li:hover,
[data-baseweb="menu"] [role="option"]:hover {
background-color: #f5f5f3 !important;
}
/* ── Buttons ── */
[data-testid="stButton"] button {
background-color: #fff !important;
color: #1a1a1a !important;
border: 1px solid #ddd !important;
border-radius: 8px !important;
font-size: 0.8rem !important;
font-weight: 500 !important;
padding: 0.3rem 0.8rem !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important;
}
[data-testid="stButton"] button:hover {
background-color: #f5f5f3 !important;
border-color: #ccc !important;
}
/* ── Metrics ── */
[data-testid="stMetric"] {
background: #fff;
border: 1px solid #ebebeb;
border-radius: 10px;
padding: 0.9rem 1.1rem;
box-shadow: 0 1px 4px rgba(0,0,0,0.04);
}
[data-testid="stMetricLabel"] > div {
font-size: 0.68rem !important;
font-weight: 600 !important;
letter-spacing: 0.08em;
text-transform: uppercase;
color: #aaa !important;
}
[data-testid="stMetricValue"] > div {
font-size: 1.6rem !important;
font-weight: 500 !important;
color: #1a1a1a !important;
}
/* ── Example container (st.container border=True) ── */
[data-testid="stVerticalBlockBorderWrapper"] {
background: #fff !important;
border: 1px solid #e8e5e0 !important;
border-radius: 12px !important;
box-shadow: 0 1px 5px rgba(0,0,0,0.05) !important;
padding: 0.1rem 0.25rem !important;
margin-bottom: 0.75rem !important;
transition: box-shadow 0.15s;
}
[data-testid="stVerticalBlockBorderWrapper"]:hover {
box-shadow: 0 3px 12px rgba(0,0,0,0.09) !important;
}
.ex-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 0.1rem 0 0.5rem 0;
border-bottom: 1px solid #f0ede8;
margin-bottom: 0.6rem;
}
.ex-card-cols {
display: flex;
gap: 1.25rem;
}
.ex-col {
flex: 1;
min-width: 0;
padding-right: 0.5rem;
}
.ex-col + .ex-col {
border-left: 1px solid #f0f0f0;
padding-left: 1rem;
padding-right: 0;
}
.ex-label {
font-size: 0.65rem;
font-weight: 600;
letter-spacing: 0.09em;
text-transform: uppercase;
color: #aaa;
margin-bottom: 0.2rem;
}
.ex-text {
font-size: 0.95rem;
line-height: 1.6;
color: #1a1a1a;
margin-bottom: 0;
}
.ex-scores { display: flex; gap: 0.5rem; margin-top: 0.75rem; }
.badge {
font-size: 0.7rem;
font-weight: 500;
padding: 0.15rem 0.55rem;
border-radius: 999px;
background: #f3f3f3;
color: #555;
border: 1px solid #e8e8e8;
}
.badge-bad { background: #fff1f1; color: #c0392b; border-color: #fad7d7; }
.badge-ok { background: #fff8ed; color: #b07d1a; border-color: #f5e0b0; }
.badge-good { background: #f0faf3; color: #1e7e44; border-color: #c3e8cf; }
.ex-filename { font-family: monospace; font-size: 0.65rem; color: #bbb; margin-top: 0.4rem; }
/* ── Dialog ── */
[data-testid="stDialog"] [data-testid="stVerticalBlock"] { gap: 0.5rem; }
.zoom-ref {
background: #f5f3ff;
border: 1px solid #ddd6fe;
border-left: 4px solid #7c6ff7;
border-radius: 8px;
padding: 0.9rem 1rem;
margin-bottom: 0.75rem;
}
.zoom-ref .ex-label { color: #7c6ff7; margin-bottom: 0.2rem; }
.zoom-ref .ex-text { font-size: 1rem; font-weight: 500; color: #1a1a1a; margin-bottom: 0; }
.zoom-row {
background: #fafafa;
border: 1px solid #ebebeb;
border-radius: 8px;
padding: 0.75rem 1rem;
margin-bottom: 0.4rem;
}
.zoom-row .ex-label { margin-bottom: 0.1rem; }
.zoom-row .ex-text { margin-bottom: 0; font-size: 0.9rem; }
hr { border: none; border-top: 1px solid #e8e8e8; margin: 0.75rem 0; }
</style>
"""
def wer_badge(w: float, label: str = "WER") -> str:
    """Return a badge ``<span>`` for a word error rate.

    The CSS class is ``badge-good`` below 0.3, ``badge-ok`` below 0.7 and
    ``badge-bad`` otherwise; the value is printed with two decimals after
    ``label``.
    """
    if w < 0.3:
        severity = "badge-good"
    elif w < 0.7:
        severity = "badge-ok"
    else:
        severity = "badge-bad"
    return f'<span class="badge {severity}">{label} {w:.2f}</span>'
def cer_badge(c: float, label: str = "CER") -> str:
    """Return a badge ``<span>`` for a character error rate.

    The CSS class is ``badge-good`` below 0.15, ``badge-ok`` below 0.4 and
    ``badge-bad`` otherwise; the value is printed with two decimals after
    ``label``.
    """
    if c < 0.15:
        severity = "badge-good"
    elif c < 0.4:
        severity = "badge-ok"
    else:
        severity = "badge-bad"
    return f'<span class="badge {severity}">{label} {c:.2f}</span>'
# Highlight palette for prediction words, keyed by the alignment kind stored
# in each slot. Every value is a (text colour, background colour) pair and is
# unpacked in that order by _word_html.
_COLORS = {
    "correct": ("#1e7e44", "#edfaf2"),
    "close": ("#916a00", "#fff8e6"),
    "wrong": ("#b91c1c", "#fff1f1"),
    "insert": ("#b91c1c", "#ffeaea"),
}
def _get_slots(ref: str, pred: str) -> list[tuple]:
    """Align ``ref`` and ``pred`` word by word.

    Both inputs are coerced with ``str()``, split on whitespace and aligned
    with ``difflib.SequenceMatcher`` (autojunk disabled). Each returned slot
    is a ``(ref_word, pred_word, kind)`` tuple in which a missing side is the
    empty string. ``kind`` is one of:

    * ``"correct"`` - identical words;
    * ``"close"``   - substituted word whose character-level similarity
      ratio is at least 0.6;
    * ``"wrong"``   - substituted word below that ratio, or a reference word
      with no counterpart in the prediction;
    * ``"insert"``  - prediction word with no counterpart in the reference.
    """
    ref_tokens = str(ref).split()
    pred_tokens = str(pred).split()
    matcher = difflib.SequenceMatcher(None, ref_tokens, pred_tokens, autojunk=False)

    aligned: list[tuple] = []
    for tag, r_lo, r_hi, p_lo, p_hi in matcher.get_opcodes():
        ref_run = ref_tokens[r_lo:r_hi]
        pred_run = pred_tokens[p_lo:p_hi]
        if tag == "equal":
            aligned.extend((r, p, "correct") for r, p in zip(ref_run, pred_run))
        elif tag == "replace":
            # Pair words positionally, grading each pair by spelling similarity.
            for r, p in zip(ref_run, pred_run):
                similarity = difflib.SequenceMatcher(None, r, p).ratio()
                aligned.append((r, p, "close" if similarity >= 0.6 else "wrong"))
            # Whatever the longer side has left over is unpaired.
            aligned.extend((r, "", "wrong") for r in ref_run[len(pred_run):])
            aligned.extend(("", p, "insert") for p in pred_run[len(ref_run):])
        elif tag == "delete":
            aligned.extend((r, "", "wrong") for r in ref_run)
        elif tag == "insert":
            aligned.extend(("", p, "insert") for p in pred_run)
    return aligned
def _word_html(w: str, kind: str, is_ref: bool, highlight: bool) -> str:
    """Render one side of an alignment slot as an HTML fragment.

    Args:
        w: The word to render; empty when this side of the slot is a gap.
        kind: Alignment kind of the slot, used as the key into ``_COLORS``.
        is_ref: True when rendering the reference side, which is never
            coloured.
        highlight: Whether prediction words get a coloured background.

    Returns:
        ``""`` for a gap, except on a highlighted prediction side where a
        faint middle-dot placeholder is returned; otherwise the HTML-escaped
        word, wrapped in a coloured ``<span>`` for a highlighted prediction.
    """
    if not w:
        if highlight and not is_ref:
            # HTML entity rather than a literal dot so the source stays ASCII
            # and the character cannot be mis-encoded.
            return '<span style="opacity:0.22">&middot;</span>'
        return ""
    esc = html_lib.escape(w)
    if is_ref or not highlight:
        return esc
    color, bg = _COLORS[kind]
    # Inserted words additionally get a dotted underline.
    extra = ";text-decoration:underline dotted" if kind == "insert" else ""
    return f'<span style="border-radius:3px;padding:1px 3px;color:{color};background:{bg}{extra}">{esc}</span>'
def render_aligned_cols(ref: str, pred: str, highlight: bool) -> str:
    """Render reference and prediction as aligned two-column rows.

    The word alignment is cut into segments; every segment becomes one row
    holding its reference words on the left and its prediction words on the
    right. A new segment starts whenever the alignment switches between
    matching and non-matching words, and after every ``MAX_EQUAL`` words
    inside a matching run, so both columns wrap at the same boundaries.
    Rows are joined with newlines.
    """
    MAX_EQUAL = 8

    # Group slots into segments of uniform "matching" status.
    segments: list[list[tuple]] = []
    for slot in _get_slots(ref, pred):
        matches = slot[2] == "correct"
        needs_new_segment = (
            not segments
            or matches != (segments[-1][0][2] == "correct")
            or (matches and len(segments[-1]) >= MAX_EQUAL)
        )
        if needs_new_segment:
            segments.append([])
        segments[-1].append(slot)

    rows = []
    for segment in segments:
        ref_html = " ".join(
            fragment
            for fragment in (_word_html(rw, kind, True, highlight) for rw, _, kind in segment)
            if fragment
        )
        pred_html = " ".join(
            fragment
            for fragment in (_word_html(pw, kind, False, highlight) for _, pw, kind in segment)
            if fragment
        )
        if not (ref_html or pred_html):
            continue
        rows.append(
            '<div class="ex-card-cols" style="margin-bottom:0.15rem;align-items:baseline">'
            f'<div class="ex-col"><span class="ex-text" style="margin:0">{ref_html}</span></div>'
            f'<div class="ex-col"><span class="ex-text" style="margin:0">{pred_html}</span></div>'
            '</div>'
        )
    return "\n".join(rows)
def render_flat_cols(ref: str, pred: str, highlight: bool) -> str:
    """Render plain side-by-side columns for reference and prediction.

    The reference is always shown as escaped plain text. With ``highlight``
    the prediction is rebuilt word by word from the alignment with each word
    coloured; otherwise it is escaped plain text as well.
    """
    if highlight:
        coloured = [
            _word_html(pw, kind, False, True)
            for _, pw, kind in _get_slots(ref, pred)
            if pw
        ]
        pred_html = " ".join(coloured)
    else:
        pred_html = html_lib.escape(str(pred))
    ref_html = html_lib.escape(str(ref))

    cells = "".join(
        f'<div class="ex-col"><span class="ex-text">{cell}</span></div>'
        for cell in (ref_html, pred_html)
    )
    return f'<div class="ex-card-cols">{cells}</div>'
def zoom_pred_html(ref: str, pred: str, highlight: bool) -> str:
    """Return the prediction as a single ``ex-text`` block.

    With ``highlight`` every alignment slot is rendered through
    ``_word_html`` for the prediction side and the non-empty fragments are
    joined with spaces; without it the prediction is escaped plain text.
    """
    if highlight:
        fragments = (
            _word_html(pw, kind, False, True) for _, pw, kind in _get_slots(ref, pred)
        )
        inner = " ".join(fragment for fragment in fragments if fragment)
    else:
        inner = html_lib.escape(str(pred))
    return f'<div class="ex-text">{inner}</div>'
def parse_result_files(results_dir: str) -> list[dict]:
    """List the evaluation CSVs found directly inside ``results_dir``.

    Directory entries are visited in sorted name order and kept when their
    name matches ``_FNAME_RE``. Each kept file yields a dict with the keys
    ``model``, ``dataset``, ``split``, ``decode`` (the four regex groups) and
    ``path`` (directory joined with the file name). An empty list is returned
    when ``results_dir`` is not a directory.
    """
    if not os.path.isdir(results_dir):
        return []

    found = []
    for fname in sorted(os.listdir(results_dir)):
        match = _FNAME_RE.match(fname)
        if match is None:
            continue
        model, dataset, split, decode = match.groups()
        found.append({
            "model": model,
            "dataset": dataset,
            "split": split,
            "decode": decode,
            "path": os.path.join(results_dir, fname),
        })
    return found
def load_all_dfs(entries: list[dict]) -> dict[tuple, pd.DataFrame]:
    """Read every result CSV listed in ``entries``.

    Each entry's ``path`` is loaded with ``pd.read_csv`` and the frame's own
    ``path`` column is passed through ``resolve_audio_path``. The frames are
    returned in a dict keyed by ``(model, dataset, split, decode)``.
    """
    loaded: dict[tuple, pd.DataFrame] = {}
    for entry in entries:
        frame = pd.read_csv(entry["path"])
        frame["path"] = frame["path"].apply(resolve_audio_path)
        run_key = (entry["model"], entry["dataset"], entry["split"], entry["decode"])
        loaded[run_key] = frame
    return loaded
# Absolute path of the directory that contains this script; relative paths
# used by the app are resolved against it.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))


def get_results_dir() -> str:
    """Return the path of the ``results`` directory next to this script."""
    return os.path.join(_APP_DIR, "results")
def resolve_audio_path(path: str) -> str:
    """Return ``path`` unchanged if absolute, else joined onto the app directory."""
    return path if os.path.isabs(path) else os.path.join(_APP_DIR, path)
# ── HF dataset audio fallback ──────────────────────────────────────────────
# Keys are the top-level folder of an audio path relative to the app
# directory; values are the Hugging Face dataset repo that holds those files.
# The path inside the repo is the relative path with that folder removed.
_HF_AUDIO_DATASETS: dict[str, str] = {
    "audio_chunks": "clb-benin/clb_data",
    "fongbe_speech_audio_files": "clb-benin/fongbe-data",
}
def _hf_token() -> str | None:
"""Return HF token from Streamlit secrets or HF_TOKEN env var, or None."""
try:
return st.secrets.get("HF_TOKEN") # type: ignore[attr-defined]
except Exception:
pass
return os.environ.get("HF_TOKEN")
@st.cache_data(show_spinner=False)
def get_audio_bytes(audio_path: str) -> bytes | None:
    """Return the bytes of an audio file, or None when it cannot be obtained.

    The local file at ``audio_path`` is preferred. When it does not exist,
    the path is taken relative to the app directory; if its top-level folder
    is a key of ``_HF_AUDIO_DATASETS`` the remainder of the path is
    downloaded from the mapped Hugging Face dataset repo.

    Download failures are logged and reported as None rather than raised.
    """
    if os.path.isfile(audio_path):
        with open(audio_path, "rb") as f:
            return f.read()

    # Derive the path relative to the app dir so it can be matched against
    # the HF mapping.
    try:
        rel = os.path.relpath(audio_path, _APP_DIR)
    except ValueError:
        return None

    parts = Path(rel).parts  # e.g. ("audio_chunks", "subfolder", "chunk_0000.wav")
    # A mapped folder with nothing below it names no file: skip the hub call.
    if len(parts) < 2 or parts[0] not in _HF_AUDIO_DATASETS:
        return None

    repo_id = _HF_AUDIO_DATASETS[parts[0]]
    hf_file = "/".join(parts[1:])  # strip the local top-level folder
    try:
        local = hf_hub_download(
            repo_id=repo_id,
            filename=hf_file,
            repo_type="dataset",
            token=_hf_token(),
        )
        with open(local, "rb") as f:
            return f.read()
    except Exception:
        # Stay best-effort for the UI, but leave a trace so that a missing
        # token or a wrong repo path can be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not fetch %s from HF dataset %s", hf_file, repo_id, exc_info=True
        )
        return None
# ── Dialog for zoomed example ──────────────────────────────────────────────
@st.dialog("Example detail", width="large")
def show_example_detail(audio_path: str, all_dfs: dict, highlight: bool):
    """Dialog comparing every loaded run's prediction for one audio file.

    Args:
        audio_path: Value of the ``path`` column identifying the example.
        all_dfs: Frames keyed by ``(model, dataset, split, decode)``.
        highlight: Whether prediction words are coloured against the
            reference.

    The reference text is taken from the first run that contains the example
    and shown once; each run containing the example then gets a row with its
    prediction and WER/CER badges.
    """
    audio_bytes = get_audio_bytes(audio_path)
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")
    else:
        st.caption("audio not found")
    st.caption(os.path.join(os.path.basename(os.path.dirname(audio_path)), os.path.basename(audio_path)))
    st.markdown("<hr>", unsafe_allow_html=True)

    ref_text = None
    for (model, dataset, split, decode), df in all_dfs.items():
        match = df[df["path"] == audio_path]
        if match.empty:
            continue
        row = match.iloc[0]
        if ref_text is None:
            ref_text = str(row["ref"])
            st.markdown(
                f'<div class="zoom-ref">'
                f'<div class="ex-label">Reference</div>'
                f'<div class="ex-text">{html_lib.escape(ref_text)}</div>'
                f'</div>',
                unsafe_allow_html=True,
            )
        pred_block = zoom_pred_html(ref_text, str(row["pred"]), highlight)
        # The run parts originate from file names, so escape them; the
        # separator is an entity so it cannot be mis-encoded.
        run_label = " &nbsp;&middot;&nbsp; ".join(
            html_lib.escape(str(part)) for part in (model, dataset, split, decode)
        )
        st.markdown(
            f'<div class="zoom-row">'
            f'<div class="ex-label">{run_label}</div>'
            f'{pred_block}'
            f'<div class="ex-scores">{wer_badge(row["wer"])}{cer_badge(row["cer"])}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
@st.dialog("Normalized text", width="large")
def show_normalized_detail(ref_norm: str, pred_norm: str, wer_n: float, cer_n: float, hl: bool):
    """Dialog showing one example's normalized reference and prediction.

    Renders the normalized WER/CER badges, a rule, the escaped normalized
    reference and finally the normalized prediction, which is word-coloured
    against the reference when ``hl`` is true.
    """
    scores_html = (
        '<div class="ex-scores" style="margin-bottom:0.75rem">'
        + wer_badge(wer_n, "WER~")
        + cer_badge(cer_n, "CER~")
        + '</div>'
    )
    st.markdown(scores_html, unsafe_allow_html=True)
    st.markdown("<hr>", unsafe_allow_html=True)

    reference_html = (
        '<div class="zoom-ref">'
        '<div class="ex-label">Reference (normalized)</div>'
        f'<div class="ex-text">{html_lib.escape(ref_norm)}</div>'
        '</div>'
    )
    st.markdown(reference_html, unsafe_allow_html=True)

    prediction_html = (
        '<div class="zoom-row">'
        '<div class="ex-label">Prediction (normalized)</div>'
        + zoom_pred_html(ref_norm, pred_norm, hl)
        + '</div>'
    )
    st.markdown(prediction_html, unsafe_allow_html=True)
# ── App ────────────────────────────────────────────────────────────────────
# Page setup and login gate; check_password() stops the run for visitors who
# are not authenticated, so nothing below loads for them.
st.set_page_config(page_title="MMS Evaluation Inspector", layout="wide")
st.markdown(CSS, unsafe_allow_html=True)
check_password()

# Discover and load every result CSV up front.
results_dir = get_results_dir()
entries = parse_result_files(results_dir)
if not entries:
    st.error(f"No evaluation CSVs found in `{results_dir}`. Run `evaluate_mms_model.py` first.")
    st.stop()
all_dfs = load_all_dfs(entries)

# Distinct values offered by the sidebar selectors, in sorted order.
models, datasets, splits, decodes = (
    sorted({entry[field] for entry in entries})
    for field in ("model", "dataset", "split", "decode")
)
# True as soon as any loaded run carries normalized scores.
has_norm = any("wer_normalized" in frame.columns for frame in all_dfs.values())
# ── Sidebar ────────────────────────────────────────────────────────────────
with st.sidebar:
    # Which run to inspect.
    st.markdown("### Run")
    sel_model = st.selectbox("Model", models)
    sel_dataset = st.selectbox("Dataset", datasets)
    sel_split = st.selectbox("Split", splits)
    sel_decode = st.selectbox("Decoding", decodes)
    st.markdown("---")

    # How to list its examples.
    st.markdown("### Display")
    sort_opts = ["original order", "WER ↑", "WER ↓", "CER ↑", "CER ↓"]
    if has_norm:
        # Normalized sort keys are only offered when normalized scores exist.
        sort_opts.extend(["WER~ ↑", "WER~ ↓", "CER~ ↑", "CER~ ↓"])
    sort_by = st.selectbox("Sort by", sort_opts)
    max_examples = st.slider("Examples", min_value=5, max_value=500, value=50, step=5)

# ── Find file ──────────────────────────────────────────────────────────────
# Not every model/dataset/split/decoding combination has a results file.
key = (sel_model, sel_dataset, sel_split, sel_decode)
df = all_dfs.get(key)
if df is None:
    st.warning("No results file found for this combination.")
    st.stop()
# ── Header ─────────────────────────────────────────────────────────────────
st.markdown("## MMS Evaluation")
# The selected values originate from result file names, so escape them before
# embedding them in raw HTML; the separator is an entity so it cannot be
# mis-encoded.
run_label = " &nbsp;&middot;&nbsp; ".join(
    html_lib.escape(str(part))
    for part in (sel_model, sel_dataset, sel_split, sel_decode)
)
st.markdown(
    f"<span style='font-size:0.8rem;color:#aaa'>{run_label}</span>",
    unsafe_allow_html=True,
)
st.markdown("<hr>", unsafe_allow_html=True)

# Summary metrics for the selected run.
summary = [
    ("Examples", len(df)),
    ("Avg WER", f"{df['wer'].mean():.3f}"),
    ("Avg CER", f"{df['cer'].mean():.3f}"),
]
# has_norm is true when *any* loaded run has normalized scores; the selected
# run may still lack them, so check its own columns to avoid a KeyError.
if {"wer_normalized", "cer_normalized"}.issubset(df.columns):
    summary += [
        ("Avg WER~", f"{df['wer_normalized'].mean():.3f}"),
        ("Avg CER~", f"{df['cer_normalized'].mean():.3f}"),
    ]
for metric_col, (metric_label, metric_value) in zip(st.columns(len(summary)), summary):
    metric_col.metric(metric_label, metric_value)
st.markdown("<hr>", unsafe_allow_html=True)
# ── Sort ───────────────────────────────────────────────────────────────────
# Sidebar label -> (column, ascending). "original order" is deliberately
# absent so the CSV order is kept for it.
_sort_map = {
    "WER ↑": ("wer", True), "WER ↓": ("wer", False),
    "CER ↑": ("cer", True), "CER ↓": ("cer", False),
    "WER~ ↑": ("wer_normalized", True), "WER~ ↓": ("wer_normalized", False),
    "CER~ ↑": ("cer_normalized", True), "CER~ ↓": ("cer_normalized", False),
}
if sort_by in _sort_map:
    sort_col, ascending = _sort_map[sort_by]
    if sort_col in df.columns:
        df = df.sort_values(sort_col, ascending=ascending).reset_index(drop=True)
    else:
        # Normalized sort keys are offered whenever any run has them, but the
        # selected run may not; fall back instead of raising KeyError.
        st.warning(f"This run has no `{sort_col}` column; showing the original order.")

# ── Beam results for inline comparison ────────────────────────────────────
# When a non-beam decoding is selected, the beam run of the same
# model/dataset/split (if loaded) is shown next to each example.
beam_df = (
    all_dfs.get((sel_model, sel_dataset, sel_split, "beam"))
    if sel_decode != "beam" else None
)
# ── Examples ───────────────────────────────────────────────────────────────
# has_norm only says that some loaded run has normalized scores; the run being
# displayed may not, so gate the normalized UI on this frame's own columns.
show_norm_scores = {"wer_normalized", "cer_normalized"}.issubset(df.columns)
show_norm_view = show_norm_scores and {"ref_norm", "pred_norm"}.issubset(df.columns)

for i, row in df.head(max_examples).iterrows():
    audio_path = row["path"]
    # Parent folder plus file name is used as the short display name.
    short_name = os.path.join(
        os.path.basename(os.path.dirname(audio_path)), os.path.basename(audio_path)
    )
    with st.container(border=True):
        # Header: filename + badges
        norm_badges = (
            f'&nbsp;{wer_badge(row["wer_normalized"], "WER~")}{cer_badge(row["cer_normalized"], "CER~")}'
            if show_norm_scores else ""
        )
        st.markdown(
            f'<div class="ex-header">'
            f'<span class="ex-filename" style="font-size:0.72rem;color:#999">{html_lib.escape(short_name)}</span>'
            f'<div class="ex-scores" style="margin:0">{wer_badge(row["wer"])}{cer_badge(row["cer"])}{norm_badges}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
        left, right = st.columns([1, 2], gap="large")

        # Left: audio player, display toggles and detail dialogs.
        with left:
            audio_bytes = get_audio_bytes(audio_path)
            if audio_bytes:
                st.audio(audio_bytes, format="audio/wav")
            else:
                st.caption("audio not found")
            tc1, tc2 = st.columns(2)
            # Widget keys carry the row index so every example has its own state.
            hl = tc1.toggle("Highlight", key=f"hl_{i}", value=False)
            al = tc2.toggle("Align", key=f"al_{i}", value=False)
            if st.button("Compare all runs", key=f"zoom_{i}", use_container_width=True):
                show_example_detail(audio_path, all_dfs, hl)
            if show_norm_view and st.button("Normalized view", key=f"norm_{i}", use_container_width=True):
                show_normalized_detail(
                    str(row["ref_norm"]), str(row["pred_norm"]),
                    float(row["wer_normalized"]), float(row["cer_normalized"]),
                    hl,
                )

        # Right: reference vs prediction, then the beam result if available.
        with right:
            labels = (
                '<div class="ex-card-cols" style="margin-bottom:0.25rem">'
                '<div class="ex-col"><div class="ex-label">Reference</div></div>'
                '<div class="ex-col"><div class="ex-label">Prediction</div></div>'
                '</div>'
            )
            if al:
                body = render_aligned_cols(str(row["ref"]), str(row["pred"]), hl)
            else:
                body = render_flat_cols(str(row["ref"]), str(row["pred"]), hl)
            st.markdown(f'{labels}{body}', unsafe_allow_html=True)

            if beam_df is not None:
                beam_match = beam_df[beam_df["path"] == audio_path]
                if not beam_match.empty:
                    beam_row = beam_match.iloc[0]
                    beam_pred_block = zoom_pred_html(str(row["ref"]), str(beam_row["pred"]), hl)
                    beam_scores = (
                        f'<div class="ex-scores">'
                        f'{wer_badge(beam_row["wer"])}{cer_badge(beam_row["cer"])}'
                        f'</div>'
                    )
                    st.markdown(
                        f'<hr style="margin:0.5rem 0">'
                        f'<div class="zoom-row">'
                        f'<div class="ex-label">Beam</div>'
                        f'{beam_pred_block}'
                        f'{beam_scores}'
                        f'</div>',
                        unsafe_allow_html=True,
                    )