"""TotTalk Cry Eval — Gradio web UI.""" from __future__ import annotations from collections import Counter import gradio as gr import librosa import numpy as np from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction from models.ensemble import EnsembleClassifier, compute_consensus # ── Load models at startup (cached in process) ─────────────────────────────── ensemble = EnsembleClassifier(use_yamnet_gate=True) ensemble.load_all() # ── Core analysis function ──────────────────────────────────────────────────── def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str: """Accept audio from Gradio, run ensemble, return styled HTML.""" if audio_tuple is None: return _wrap("Upload or record audio to get started.") sr, data = audio_tuple # Gradio gives int16 or float — normalize to float32 if data.dtype != np.float32: data = data.astype(np.float32) / max(np.abs(data).max(), 1) # Mono if data.ndim > 1: data = data.mean(axis=1) # Resample to 16 kHz if sr != SAMPLE_RATE: data = resample(data, sr, SAMPLE_RATE) # Pick the loudest 1-second window window_len = SAMPLE_RATE hop = window_len // 2 best_window = None best_rms = 0.0 for start in range(0, len(data) - window_len + 1, hop): chunk = data[start : start + window_len] rms = float(np.sqrt(np.mean(chunk**2))) if rms > best_rms: best_rms = rms best_window = chunk if best_window is None or is_silent(best_window): return _card("Result", "No cry detected", "The audio seems silent or doesn't contain a baby cry.") best_window = normalize_audio(best_window) predictions = ensemble.predict_all(best_window, SAMPLE_RATE) return _render_results(predictions) # ── HTML renderers ──────────────────────────────────────────────────────────── def _render_results(predictions: list[CryPrediction]) -> str: """Build the full results HTML.""" parts: list[str] = [] # Consensus consensus_text = compute_consensus(predictions) if consensus_text: valid = [p.label for p in predictions if p.model_name != "YAMNet-detector" and not p.error and p.label not in ("no_cry", "timeout", "error")] winning = Counter(valid).most_common(1)[0][0] if valid else "" advice = LABEL_MEANING.get(winning, "") parts.append(_card("Consensus", consensus_text, advice)) # Model breakdown parts.append('
Model breakdown
') for pred in predictions: if pred.label == "no_cry" and pred.confidence == 0.0: continue emoji = LABEL_EMOJI.get(pred.label, "") label = pred.label.replace("_", " ").title() pct = int(pred.confidence * 100) parts.append( f'
' f'
{pred.model_name}
' f'
{emoji} {label}
' f'
' f'{pct}% confidence · {pred.latency_ms:.0f} ms
' f'
' f'
' ) return "\n".join(parts) def _card(title: str, main: str, sub: str = "") -> str: """A centered highlight card.""" sub_html = f'
{sub}
' if sub else "" return ( f'
' f'
{title}
' f'
{main}
' f'{sub_html}
' ) def _wrap(msg: str) -> str: return f'
{msg}
' # ── Custom CSS for dark monochrome look ─────────────────────────────────────── CUSTOM_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); body, .gradio-container { font-family: 'Inter', sans-serif !important; } .gradio-container { max-width: 720px !important; margin: auto !important; } footer { display: none !important; } h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; } """ # ── Theme ───────────────────────────────────────────────────────────────────── THEME = gr.themes.Base( primary_hue=gr.themes.colors.gray, secondary_hue=gr.themes.colors.gray, neutral_hue=gr.themes.colors.gray, font=gr.themes.GoogleFont("Inter"), ).set( body_background_fill="#0a0a0a", body_background_fill_dark="#0a0a0a", block_background_fill="#111111", block_background_fill_dark="#111111", block_border_color="#222222", block_border_color_dark="#222222", block_label_text_color="#666666", block_label_text_color_dark="#666666", block_title_text_color="#e0e0e0", block_title_text_color_dark="#e0e0e0", body_text_color="#e0e0e0", body_text_color_dark="#e0e0e0", body_text_color_subdued="#666666", body_text_color_subdued_dark="#666666", button_primary_background_fill="transparent", button_primary_background_fill_dark="transparent", button_primary_border_color="#222222", button_primary_border_color_dark="#222222", button_primary_text_color="#e0e0e0", button_primary_text_color_dark="#e0e0e0", input_background_fill="#111111", input_background_fill_dark="#111111", input_border_color="#222222", input_border_color_dark="#222222", ) # ── App ─────────────────────────────────────────────────────────────────────── with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk · Cry Classifier") as app: gr.Markdown("# 👶 TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.") with gr.Tabs(): with gr.TabItem("🎙 Record"): mic_input = gr.Audio( sources=["microphone"], type="numpy", label="Record from mic", ) mic_btn = gr.Button("Analyze recording", variant="primary", size="lg") with gr.TabItem("📁 Upload file"): file_input = gr.Audio( sources=["upload"], type="numpy", label="Upload WAV / MP3 / FLAC", ) file_btn = gr.Button("Analyze file", variant="primary", size="lg") output = gr.HTML( value=_wrap("Upload or record audio above, then click Analyze."), label="Results", ) mic_btn.click(fn=analyze, inputs=mic_input, outputs=output) file_btn.click(fn=analyze, inputs=file_input, outputs=output) gr.Markdown( '

' "TotTalk Cry Eval · Open-source multi-model comparison tool · " "Models run server-side — your audio is not stored.

" ) if __name__ == "__main__": app.launch()