| """TotTalk Cry Eval β Gradio web UI.""" | |
| from __future__ import annotations | |
| from collections import Counter | |
import gradio as gr
import numpy as np

from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import EnsembleClassifier, compute_consensus

# ── Load models at startup (cached in process) ─────────────────────────────────
ensemble = EnsembleClassifier(use_yamnet_gate=True)
ensemble.load_all()


# ── Core analysis function ─────────────────────────────────────────────────────
def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str:
    """Accept audio from Gradio, run the ensemble, and return styled HTML."""
    if audio_tuple is None:
        return _wrap("Upload or record audio to get started.")
    sr, data = audio_tuple
    # Gradio delivers int16 (or sometimes float) samples → convert to float32.
    if np.issubdtype(data.dtype, np.integer):
        # Scale integer PCM by the dtype's full range rather than the clip's own
        # peak, so quiet recordings still look quiet to the silence check below.
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    elif data.dtype != np.float32:
        data = data.astype(np.float32)
    # Downmix to mono.
    if data.ndim > 1:
        data = data.mean(axis=1)
    # Resample to the model rate (16 kHz).
    if sr != SAMPLE_RATE:
        data = resample(data, sr, SAMPLE_RATE)
    # Pick the loudest 1-second window by RMS, sliding with 50% overlap.
    window_len = SAMPLE_RATE
    hop = window_len // 2
    if len(data) < window_len:
        # Clips shorter than one second would otherwise produce no window at
        # all; zero-pad so short recordings are still analyzed.
        data = np.pad(data, (0, window_len - len(data)))
    best_window = None
    best_rms = 0.0
    for start in range(0, len(data) - window_len + 1, hop):
        chunk = data[start : start + window_len]
        rms = float(np.sqrt(np.mean(chunk**2)))
        if rms > best_rms:
            best_rms = rms
            best_window = chunk
    if best_window is None or is_silent(best_window):
        return _card("Result", "No cry detected",
                     "The audio seems silent or doesn't contain a baby cry.")
    best_window = normalize_audio(best_window)
    predictions = ensemble.predict_all(best_window, SAMPLE_RATE)
    return _render_results(predictions)
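
# A quick way to exercise `analyze` without the UI (a sketch, assuming the
# ensemble loaded above): pass a (rate, samples) tuple just as Gradio would.
# Synthetic noise should normally come back as "no cry":
#
#     rng = np.random.default_rng(0)
#     fake = (0.1 * rng.standard_normal(SAMPLE_RATE)).astype(np.float32)
#     print(analyze((SAMPLE_RATE, fake)))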


# ── HTML renderers ──────────────────────────────────────────────────────────────
def _render_results(predictions: list[CryPrediction]) -> str:
    """Build the full results HTML: consensus card plus per-model breakdown."""
    parts: list[str] = []
    # Consensus across the classifiers (the YAMNet gate doesn't vote).
    consensus_text = compute_consensus(predictions)
    if consensus_text:
        valid = [p.label for p in predictions
                 if p.model_name != "YAMNet-detector"
                 and not p.error
                 and p.label not in ("no_cry", "timeout", "error")]
        winning = Counter(valid).most_common(1)[0][0] if valid else ""
        advice = LABEL_MEANING.get(winning, "")
        parts.append(_card("Consensus", consensus_text, advice))
    # Per-model breakdown.
    parts.append('<div style="margin-top:1.25rem; font-size:0.7rem; '
                 'text-transform:uppercase; letter-spacing:0.08em; '
                 'color:#666; font-weight:500;">Model breakdown</div>')
    for pred in predictions:
        # Skip placeholder results that carry no signal.
        if pred.label == "no_cry" and pred.confidence == 0.0:
            continue
        emoji = LABEL_EMOJI.get(pred.label, "")
        label = pred.label.replace("_", " ").title()
        pct = int(pred.confidence * 100)
        parts.append(
            f'<div style="background:#111; border:1px solid #222; '
            f'border-radius:12px; padding:1.1rem 1.4rem; margin-top:0.6rem;">'
            f'<div style="font-size:0.7rem; text-transform:uppercase; '
            f'letter-spacing:0.08em; color:#666; font-weight:500;">{pred.model_name}</div>'
            f'<div style="font-size:1.3rem; font-weight:600; color:#fff; '
            f'margin-top:0.15rem;">{emoji} {label}</div>'
            f'<div style="font-size:0.8rem; color:#666; margin-top:0.1rem;">'
            f'{pct}% confidence · {pred.latency_ms:.0f} ms</div>'
            f'<div style="background:#1a1a1a; border-radius:4px; height:6px; '
            f'margin-top:0.4rem;">'
            f'<div style="background:#fff; border-radius:4px; height:6px; '
            f'width:{pct}%;"></div></div></div>'
        )
    return "\n".join(parts)


def _card(title: str, main: str, sub: str = "") -> str:
    """A centered highlight card with a title, main line, and optional subtext."""
    sub_html = (f'<div style="font-size:0.85rem; color:#666; margin-top:0.5rem; '
                f'font-style:italic;">{sub}</div>') if sub else ""
    return (
        f'<div style="background:#111; border:1px solid #333; border-radius:12px; '
        f'padding:1.5rem; text-align:center;">'
        f'<div style="font-size:0.7rem; text-transform:uppercase; '
        f'letter-spacing:0.1em; color:#666;">{title}</div>'
        f'<div style="font-size:1.7rem; font-weight:300; color:#fff; '
        f'margin-top:0.25rem;">{main}</div>'
        f'{sub_html}</div>'
    )
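
# For example, _card("Consensus", "Hungry", "Try feeding.") renders a dark
# centered card: uppercase title, large light-weight main line, italic subtext.
# (The strings above are illustrative, not actual model output.)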


def _wrap(msg: str) -> str:
    """Plain centered helper text for the empty/initial state."""
    return f'<div style="text-align:center; color:#666; padding:2rem 0;">{msg}</div>'


# ── Custom CSS for dark monochrome look ─────────────────────────────────────────
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.gradio-container { max-width: 720px !important; margin: auto !important; }
footer { display: none !important; }
h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; }
"""

# ── Theme ───────────────────────────────────────────────────────────────────────
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.gray,
    secondary_hue=gr.themes.colors.gray,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Inter"),
).set(
    body_background_fill="#0a0a0a",
    body_background_fill_dark="#0a0a0a",
    block_background_fill="#111111",
    block_background_fill_dark="#111111",
    block_border_color="#222222",
    block_border_color_dark="#222222",
    block_label_text_color="#666666",
    block_label_text_color_dark="#666666",
    block_title_text_color="#e0e0e0",
    block_title_text_color_dark="#e0e0e0",
    body_text_color="#e0e0e0",
    body_text_color_dark="#e0e0e0",
    body_text_color_subdued="#666666",
    body_text_color_subdued_dark="#666666",
    button_primary_background_fill="transparent",
    button_primary_background_fill_dark="transparent",
    button_primary_border_color="#222222",
    button_primary_border_color_dark="#222222",
    button_primary_text_color="#e0e0e0",
    button_primary_text_color_dark="#e0e0e0",
    input_background_fill="#111111",
    input_background_fill_dark="#111111",
    input_border_color="#222222",
    input_border_color_dark="#222222",
)

# ── App ─────────────────────────────────────────────────────────────────────────
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk · Cry Classifier") as app:
    gr.Markdown("# 👶 TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.")
    with gr.Tabs():
        with gr.TabItem("🎙 Record"):
            mic_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Record from mic",
            )
            mic_btn = gr.Button("Analyze recording", variant="primary", size="lg")
        with gr.TabItem("📁 Upload file"):
            file_input = gr.Audio(
                sources=["upload"],
                type="numpy",
                label="Upload WAV / MP3 / FLAC",
            )
            file_btn = gr.Button("Analyze file", variant="primary", size="lg")
    output = gr.HTML(
        value=_wrap("Upload or record audio above, then click Analyze."),
        label="Results",
    )
    mic_btn.click(fn=analyze, inputs=mic_input, outputs=output)
    file_btn.click(fn=analyze, inputs=file_input, outputs=output)
    gr.Markdown(
        '<p style="text-align:center; font-size:0.75rem; color:#444; margin-top:2rem;">'
        "TotTalk Cry Eval · Open-source multi-model comparison tool · "
        "Models run server-side; your audio is not stored.</p>"
    )
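
# Optional local smoke test: a minimal sketch, assuming the `soundfile` package
# is installed and a clip exists at the hypothetical path "sample_cry.wav":
#
#     import soundfile as sf
#     samples, sr = sf.read("sample_cry.wav", dtype="float32")
#     print(analyze((sr, samples)))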

if __name__ == "__main__":
    app.launch()