"""TotTalk Cry Eval — Gradio web UI."""
from __future__ import annotations
from collections import Counter
import gradio as gr
import librosa
import numpy as np
from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import EnsembleClassifier, compute_consensus
# ── Load models at startup (cached in process) ──────────────────────────────
# Built and loaded once when this module is imported; `analyze` reuses this
# single module-level instance for every request instead of reloading models.
# NOTE(review): `use_yamnet_gate=True` presumably enables a YAMNet-based
# cry/no-cry gate in front of the classifiers — confirm in models.ensemble.
ensemble = EnsembleClassifier(use_yamnet_gate=True)
ensemble.load_all()
# ── Core analysis function ───────────────────────────────────────────────────
def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str:
    """Classify the loudest one-second window of a Gradio audio clip.

    Args:
        audio_tuple: ``(sample_rate, samples)`` as delivered by a
            ``gr.Audio(type="numpy")`` component, or ``None`` when nothing
            has been recorded or uploaded yet.

    Returns:
        An HTML fragment: a prompt when there is no input, a "No cry
        detected" card when the clip is empty or silent, otherwise the
        rendered ensemble results.
    """
    if audio_tuple is None:
        return _wrap("Upload or record audio to get started.")

    no_cry_card = (
        "Result",
        "No cry detected",
        "The audio seems silent or doesn't contain a baby cry.",
    )

    sr, data = audio_tuple
    data = np.asarray(data)

    # An empty buffer has no peak to scale by; report it like silence
    # instead of letting ``.max()`` raise on a zero-size array.
    if data.size == 0:
        return _card(*no_cry_card)

    # Gradio gives int16 or float — normalize to float32.
    if data.dtype != np.float32:
        # Cast *before* taking the peak: abs() on raw int16 wraps -32768
        # around, and dividing by a wider NumPy scalar could promote the
        # result away from float32.
        data = data.astype(np.float32)
        peak = float(np.abs(data).max())
        data = data / np.float32(max(peak, 1.0))

    # Mono: average the channels (second axis).
    if data.ndim > 1:
        data = data.mean(axis=1)

    # Resample to the rate the models expect.
    if sr != SAMPLE_RATE:
        data = resample(data, sr, SAMPLE_RATE)

    # One window is SAMPLE_RATE samples, i.e. one second of audio.
    window_len = SAMPLE_RATE
    hop = window_len // 2

    # A clip shorter than one window would otherwise produce no candidate at
    # all and always be reported as "no cry"; zero-pad it to a full window.
    if len(data) < window_len:
        data = np.pad(data, (0, window_len - len(data)))

    # Pick the loudest window (highest RMS) using half-overlapping steps.
    best_window = None
    best_rms = 0.0
    for start in range(0, len(data) - window_len + 1, hop):
        chunk = data[start : start + window_len]
        rms = float(np.sqrt(np.mean(chunk**2)))
        if rms > best_rms:
            best_rms = rms
            best_window = chunk

    # ``best_window`` stays None when every window has zero energy.
    if best_window is None or is_silent(best_window):
        return _card(*no_cry_card)

    best_window = normalize_audio(best_window)
    predictions = ensemble.predict_all(best_window, SAMPLE_RATE)
    return _render_results(predictions)
# ── HTML renderers ───────────────────────────────────────────────────────────
def _render_results(predictions: list[CryPrediction]) -> str:
    """Build the full results HTML for a set of model predictions.

    Args:
        predictions: One prediction per model, as returned by the ensemble.

    Returns:
        Newline-joined HTML: an optional consensus card, a "Model breakdown"
        heading, and one panel per prediction that is not a zero-confidence
        ``no_cry`` placeholder.
    """
    parts: list[str] = []

    # Consensus card — only shown when the ensemble produced a consensus text.
    consensus_text = compute_consensus(predictions)
    if consensus_text:
        # Labels that count as real votes: skip the detector model, failed
        # models and the non-classification labels.
        valid = [
            p.label
            for p in predictions
            if p.model_name != "YAMNet-detector"
            and not p.error
            and p.label not in ("no_cry", "timeout", "error")
        ]
        winning = Counter(valid).most_common(1)[0][0] if valid else ""
        advice = LABEL_MEANING.get(winning, "")
        parts.append(_card("Consensus", consensus_text, advice))

    # Model breakdown heading.
    parts.append(
        '<div style="margin-top:1.25rem; font-size:0.7rem; '
        'text-transform:uppercase; letter-spacing:0.08em; '
        'color:#666; font-weight:500;">Model breakdown</div>'
    )

    for pred in predictions:
        if pred.label == "no_cry" and pred.confidence == 0.0:
            continue
        emoji = LABEL_EMOJI.get(pred.label, "")
        label = pred.label.replace("_", " ").title()
        # Round to nearest instead of truncating (0.29 * 100 is 28.99… in
        # binary floating point) and clamp so the bar never leaves its track.
        pct = max(0, min(100, int(round(float(pred.confidence) * 100))))
        parts.append(
            f'<div style="background:#111; border:1px solid #222; '
            f'border-radius:12px; padding:1.1rem 1.4rem; margin-top:0.6rem;">'
            f'<div style="font-size:0.7rem; text-transform:uppercase; '
            f'letter-spacing:0.08em; color:#666; font-weight:500;">{pred.model_name}</div>'
            f'<div style="font-size:1.3rem; font-weight:600; color:#fff; '
            f'margin-top:0.15rem;">{emoji} {label}</div>'
            f'<div style="font-size:0.8rem; color:#666; margin-top:0.1rem;">'
            # HTML entity keeps the middle-dot separator encoding-proof.
            f'{pct}% confidence &middot; {pred.latency_ms:.0f} ms</div>'
            f'<div style="background:#1a1a1a; border-radius:4px; height:6px; '
            f'margin-top:0.4rem;">'
            f'<div style="background:#fff; border-radius:4px; height:6px; '
            f'width:{pct}%;"></div></div></div>'
        )
    return "\n".join(parts)
def _card(title: str, main: str, sub: str = "") -> str:
"""A centered highlight card."""
sub_html = f'<div style="font-size:0.85rem; color:#666; margin-top:0.5rem; font-style:italic;">{sub}</div>' if sub else ""
return (
f'<div style="background:#111; border:1px solid #333; border-radius:12px; '
f'padding:1.5rem; text-align:center;">'
f'<div style="font-size:0.7rem; text-transform:uppercase; '
f'letter-spacing:0.1em; color:#666;">{title}</div>'
f'<div style="font-size:1.7rem; font-weight:300; color:#fff; '
f'margin-top:0.25rem;">{main}</div>'
f'{sub_html}</div>'
)
def _wrap(msg: str) -> str:
return f'<div style="text-align:center; color:#666; padding:2rem 0;">{msg}</div>'
# ── Custom CSS for dark monochrome look ──────────────────────────────────────
# Stylesheet passed to ``gr.Blocks(css=...)``: loads the Inter font, narrows
# and centers the container, hides the footer and lightens the heading.
# The empty first/last entries reproduce the leading and trailing newline.
CUSTOM_CSS = "\n".join(
    (
        "",
        "@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');",
        "body, .gradio-container { font-family: 'Inter', sans-serif !important; }",
        ".gradio-container { max-width: 720px !important; margin: auto !important; }",
        "footer { display: none !important; }",
        "h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; }",
        "",
    )
)
# ── Theme ────────────────────────────────────────────────────────────────────
# Colour overrides for the theme. Each entry is applied twice below: once
# under its own name and once under the matching ``*_dark`` name, with the
# same value.
_THEME_COLORS = {
    "body_background_fill": "#0a0a0a",
    "block_background_fill": "#111111",
    "block_border_color": "#222222",
    "block_label_text_color": "#666666",
    "block_title_text_color": "#e0e0e0",
    "body_text_color": "#e0e0e0",
    "body_text_color_subdued": "#666666",
    "button_primary_background_fill": "transparent",
    "button_primary_border_color": "#222222",
    "button_primary_text_color": "#e0e0e0",
    "input_background_fill": "#111111",
    "input_border_color": "#222222",
}

# Gray-only base theme using the Inter font, with the overrides above.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.gray,
    secondary_hue=gr.themes.colors.gray,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Inter"),
).set(
    **_THEME_COLORS,
    **{f"{name}_dark": value for name, value in _THEME_COLORS.items()},
)
# ── App ──────────────────────────────────────────────────────────────────────
# NOTE(review): the nesting below was reconstructed from a flattened source —
# confirm that the results panel and footer sit below the tabs, not inside one.
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk Β· Cry Classifier") as app:
    gr.Markdown("# πΆ TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.")

    with gr.Tabs():
        # Tab 1: record straight from the microphone.
        with gr.TabItem("π Record"):
            mic_input = gr.Audio(sources=["microphone"], type="numpy", label="Record from mic")
            mic_btn = gr.Button("Analyze recording", variant="primary", size="lg")

        # Tab 2: upload an existing audio file.
        with gr.TabItem("π Upload file"):
            file_input = gr.Audio(sources=["upload"], type="numpy", label="Upload WAV / MP3 / FLAC")
            file_btn = gr.Button("Analyze file", variant="primary", size="lg")

    # Single results panel shared by both tabs.
    output = gr.HTML(
        value=_wrap("Upload or record audio above, then click Analyze."),
        label="Results",
    )

    # Both buttons run the same analysis on their own audio component.
    for _button, _source in ((mic_btn, mic_input), (file_btn, file_input)):
        _button.click(fn=analyze, inputs=_source, outputs=output)

    gr.Markdown(
        '<p style="text-align:center; font-size:0.75rem; color:#444; margin-top:2rem;">'
        "TotTalk Cry Eval Β· Open-source multi-model comparison tool Β· "
        "Models run server-side β your audio is not stored.</p>"
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    app.launch()
|