"""TotTalk Cry Eval — Gradio web UI."""
from __future__ import annotations
from collections import Counter
import gradio as gr
import librosa
import numpy as np
from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import EnsembleClassifier, compute_consensus
# ── Load models at startup (cached in process) ───────────────────────────────
# Module-level singleton: models load once at import time so each request pays
# only inference cost. use_yamnet_gate=True enables the YAMNet detector model
# (presumably a pre-filter for non-cry audio — confirm in models.ensemble).
ensemble = EnsembleClassifier(use_yamnet_gate=True)
ensemble.load_all()
# ── Core analysis function ────────────────────────────────────────────────────
def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str:
    """Accept audio from Gradio, run the ensemble, and return styled HTML.

    Args:
        audio_tuple: ``(sample_rate, samples)`` as produced by ``gr.Audio``
            with ``type="numpy"``, or ``None`` when no audio was provided.

    Returns:
        An HTML string for the results panel (placeholder text, a
        "no cry detected" card, or the full multi-model breakdown).
    """
    if audio_tuple is None:
        return _wrap("Upload or record audio to get started.")
    sr, data = audio_tuple
    # Gradio may deliver int16 (or other) samples — normalize to float32
    # in roughly [-1, 1]; max(..., 1) guards against dividing by zero on
    # an all-zero buffer.
    if data.dtype != np.float32:
        data = data.astype(np.float32) / max(np.abs(data).max(), 1)
    # Down-mix stereo/multichannel to mono.
    if data.ndim > 1:
        data = data.mean(axis=1)
    # Resample to the models' expected rate (16 kHz per audio.preprocess).
    if sr != SAMPLE_RATE:
        data = resample(data, sr, SAMPLE_RATE)
    window_len = SAMPLE_RATE  # 1 second of samples
    # BUG FIX: clips shorter than one window previously yielded an empty
    # range in the loop below, so best_window stayed None and perfectly
    # valid short recordings were reported as "no cry". Zero-pad short
    # clips so they produce exactly one analysis window.
    if len(data) < window_len:
        data = np.pad(data, (0, window_len - len(data)))
    # Slide a 1-second window with 50% overlap and keep the loudest (RMS).
    hop = window_len // 2
    best_window = None
    best_rms = 0.0
    for start in range(0, len(data) - window_len + 1, hop):
        chunk = data[start : start + window_len]
        rms = float(np.sqrt(np.mean(chunk**2)))
        if rms > best_rms:
            best_rms = rms
            best_window = chunk
    if best_window is None or is_silent(best_window):
        return _card("Result", "No cry detected",
                     "The audio seems silent or doesn't contain a baby cry.")
    best_window = normalize_audio(best_window)
    predictions = ensemble.predict_all(best_window, SAMPLE_RATE)
    return _render_results(predictions)
# ── HTML renderers ────────────────────────────────────────────────────────────
def _render_results(predictions: list[CryPrediction]) -> str:
    """Build the full results HTML: a consensus card plus a per-model list.

    NOTE(review): the HTML tags in this function were lost in the source as
    received and have been reconstructed to match the app's dark monochrome
    styling — verify the markup against the intended design.
    """
    parts: list[str] = []
    # Consensus card — majority vote over real classifier outputs; the
    # YAMNet gate and error/timeout/no_cry placeholders are excluded.
    consensus_text = compute_consensus(predictions)
    if consensus_text:
        valid = [p.label for p in predictions
                 if p.model_name != "YAMNet-detector"
                 and not p.error
                 and p.label not in ("no_cry", "timeout", "error")]
        winning = Counter(valid).most_common(1)[0][0] if valid else ""
        advice = LABEL_MEANING.get(winning, "")
        parts.append(_card("Consensus", consensus_text, advice))
    # Section header for the per-model breakdown.
    parts.append(
        '<div style="margin:24px 0 8px;color:#666;font-size:12px;'
        'text-transform:uppercase;letter-spacing:0.12em;">'
        'Model breakdown</div>'
    )
    for pred in predictions:
        # Skip placeholder rows that carry no signal.
        if pred.label == "no_cry" and pred.confidence == 0.0:
            continue
        emoji = LABEL_EMOJI.get(pred.label, "")
        label = pred.label.replace("_", " ").title()
        pct = int(pred.confidence * 100)
        parts.append(
            f'<div style="border:1px solid #222;border-radius:8px;'
            f'padding:12px 16px;margin-bottom:8px;">'
            f'<div style="color:#666;font-size:12px;">{pred.model_name}</div>'
            f'<div style="font-size:18px;margin:4px 0;">{emoji} {label}</div>'
            f'<div style="background:#222;border-radius:4px;height:6px;">'
            f'<div style="background:#e0e0e0;border-radius:4px;height:6px;'
            f'width:{pct}%;"></div></div>'
            f'<div style="color:#666;font-size:12px;margin-top:6px;">'
            f'{pct}% confidence · {pred.latency_ms:.0f} ms</div>'
            f'</div>'
        )
    return "\n".join(parts)
def _card(title: str, main: str, sub: str = "") -> str:
"""A centered highlight card."""
sub_html = f'{sub}
' if sub else ""
return (
f''
f'
{title}
'
f'
{main}
'
f'{sub_html}
'
)
def _wrap(msg: str) -> str:
return f'{msg}
'
# ── Custom CSS for dark monochrome look ───────────────────────────────────────
# Injected into the Gradio page: Inter web font, a narrow centered layout,
# and no Gradio footer.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.gradio-container { max-width: 720px !important; margin: auto !important; }
footer { display: none !important; }
h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; }
"""
# ── Theme ─────────────────────────────────────────────────────────────────────
# Monochrome dark theme: every hue is gray, with the same near-black palette
# forced for both light and dark modes (each *_dark key mirrors its base key).
THEME = gr.themes.Base(
primary_hue=gr.themes.colors.gray,
secondary_hue=gr.themes.colors.gray,
neutral_hue=gr.themes.colors.gray,
font=gr.themes.GoogleFont("Inter"),
).set(
body_background_fill="#0a0a0a",
body_background_fill_dark="#0a0a0a",
block_background_fill="#111111",
block_background_fill_dark="#111111",
block_border_color="#222222",
block_border_color_dark="#222222",
block_label_text_color="#666666",
block_label_text_color_dark="#666666",
block_title_text_color="#e0e0e0",
block_title_text_color_dark="#e0e0e0",
body_text_color="#e0e0e0",
body_text_color_dark="#e0e0e0",
body_text_color_subdued="#666666",
body_text_color_subdued_dark="#666666",
button_primary_background_fill="transparent",
button_primary_background_fill_dark="transparent",
button_primary_border_color="#222222",
button_primary_border_color_dark="#222222",
button_primary_text_color="#e0e0e0",
button_primary_text_color_dark="#e0e0e0",
input_background_fill="#111111",
input_background_fill_dark="#111111",
input_border_color="#222222",
input_border_color_dark="#222222",
)
# ── App ───────────────────────────────────────────────────────────────────────
# Two input tabs (mic / file upload) feed the same analyze() callback, whose
# HTML result lands in a shared output panel below the tabs.
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk · Cry Classifier") as app:
    gr.Markdown("# 👶 TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.")
    with gr.Tabs():
        with gr.TabItem("🎙 Record"):
            mic_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Record from mic",
            )
            mic_btn = gr.Button("Analyze recording", variant="primary", size="lg")
        with gr.TabItem("📁 Upload file"):
            file_input = gr.Audio(
                sources=["upload"],
                type="numpy",
                label="Upload WAV / MP3 / FLAC",
            )
            file_btn = gr.Button("Analyze file", variant="primary", size="lg")
    output = gr.HTML(
        value=_wrap("Upload or record audio above, then click Analyze."),
        label="Results",
    )
    mic_btn.click(fn=analyze, inputs=mic_input, outputs=output)
    file_btn.click(fn=analyze, inputs=file_input, outputs=output)
    # Footer. NOTE(review): the wrapper tag was lost in the source as
    # received and has been reconstructed — verify against the design.
    gr.Markdown(
        '<div style="text-align:center;color:#666;font-size:12px;margin-top:24px;">'
        "TotTalk Cry Eval · Open-source multi-model comparison tool · "
        "Models run server-side — your audio is not stored."
        "</div>"
    )

if __name__ == "__main__":
    app.launch()