# tot-talk / app.py
# NOTE: the lines below were Hugging Face Spaces page chrome captured by a
# scrape; preserved as comments so the module stays importable.
# Commit 9c68f51 — "Fix: move theme/css to Blocks() for HF Spaces compat"
"""TotTalk Cry Eval β€” Gradio web UI."""
from __future__ import annotations
from collections import Counter
import gradio as gr
import librosa
import numpy as np
from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import EnsembleClassifier, compute_consensus
# ── Load models at startup (cached in process) ───────────────────────────────
# Module-level singleton: constructed once at import time so every request in
# this process reuses the already-loaded models instead of reloading per call.
# use_yamnet_gate=True presumably enables a YAMNet-based cry/no-cry gate —
# confirm in models.ensemble.
ensemble = EnsembleClassifier(use_yamnet_gate=True)
ensemble.load_all()
# ── Core analysis function ────────────────────────────────────────────────────
def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str:
    """Accept audio from Gradio, run the ensemble, and return styled HTML.

    Args:
        audio_tuple: ``(sample_rate, samples)`` as produced by a
            ``gr.Audio(type="numpy")`` component, or ``None`` when the
            component is empty.

    Returns:
        An HTML fragment: either a placeholder prompt, a "no cry detected"
        card, or the full results rendered by ``_render_results``.
    """
    if audio_tuple is None:
        return _wrap("Upload or record audio to get started.")
    sr, data = audio_tuple
    # Gradio gives int16 or float — normalize to float32. Dividing by the
    # clip's own peak (not a fixed 32768) rescales amplitude, which is fine
    # here because the chosen window is re-normalized before inference anyway.
    if data.dtype != np.float32:
        data = data.astype(np.float32) / max(np.abs(data).max(), 1)
    # Mix down to mono.
    if data.ndim > 1:
        data = data.mean(axis=1)
    # Resample to the models' expected rate (16 kHz).
    if sr != SAMPLE_RATE:
        data = resample(data, sr, SAMPLE_RATE)
    # Pick the loudest (highest-RMS) 1-second window, scanning at 50 % overlap.
    window_len = SAMPLE_RATE
    # Bug fix: clips shorter than one window previously made the scan loop
    # run zero times, so best_window stayed None and a short-but-valid cry was
    # reported as "No cry detected". Zero-pad short clips to one full window.
    if len(data) < window_len:
        data = np.pad(data, (0, window_len - len(data)))
    hop = window_len // 2
    best_window = None
    best_rms = 0.0
    for start in range(0, len(data) - window_len + 1, hop):
        chunk = data[start : start + window_len]
        rms = float(np.sqrt(np.mean(chunk**2)))
        if rms > best_rms:  # strict '>' keeps all-zero audio unselected
            best_rms = rms
            best_window = chunk
    if best_window is None or is_silent(best_window):
        return _card("Result", "No cry detected",
                     "The audio seems silent or doesn't contain a baby cry.")
    best_window = normalize_audio(best_window)
    predictions = ensemble.predict_all(best_window, SAMPLE_RATE)
    return _render_results(predictions)
# ── HTML renderers ────────────────────────────────────────────────────────────
def _render_results(predictions: list[CryPrediction]) -> str:
    """Render a consensus card plus a per-model breakdown as one HTML string."""
    html: list[str] = []

    # Consensus card — only emitted when the ensemble reached one.
    consensus_text = compute_consensus(predictions)
    if consensus_text:
        votes = [
            p.label
            for p in predictions
            if p.model_name != "YAMNet-detector"
            and not p.error
            and p.label not in ("no_cry", "timeout", "error")
        ]
        top_label = Counter(votes).most_common(1)[0][0] if votes else ""
        html.append(_card("Consensus", consensus_text,
                          LABEL_MEANING.get(top_label, "")))

    # Section header for the per-model cards.
    html.append('<div style="margin-top:1.25rem; font-size:0.7rem; '
                'text-transform:uppercase; letter-spacing:0.08em; '
                'color:#666; font-weight:500;">Model breakdown</div>')

    # One card per model; models with no signal at all are skipped.
    for p in predictions:
        if p.label == "no_cry" and p.confidence == 0.0:
            continue
        pct = int(p.confidence * 100)
        icon = LABEL_EMOJI.get(p.label, "")
        title = p.label.replace("_", " ").title()
        html.append(
            f'<div style="background:#111; border:1px solid #222; '
            f'border-radius:12px; padding:1.1rem 1.4rem; margin-top:0.6rem;">'
            f'<div style="font-size:0.7rem; text-transform:uppercase; '
            f'letter-spacing:0.08em; color:#666; font-weight:500;">{p.model_name}</div>'
            f'<div style="font-size:1.3rem; font-weight:600; color:#fff; '
            f'margin-top:0.15rem;">{icon} {title}</div>'
            f'<div style="font-size:0.8rem; color:#666; margin-top:0.1rem;">'
            f'{pct}% confidence Β· {p.latency_ms:.0f} ms</div>'
            f'<div style="background:#1a1a1a; border-radius:4px; height:6px; '
            f'margin-top:0.4rem;">'
            f'<div style="background:#fff; border-radius:4px; height:6px; '
            f'width:{pct}%;"></div></div></div>'
        )
    return "\n".join(html)
def _card(title: str, main: str, sub: str = "") -> str:
"""A centered highlight card."""
sub_html = f'<div style="font-size:0.85rem; color:#666; margin-top:0.5rem; font-style:italic;">{sub}</div>' if sub else ""
return (
f'<div style="background:#111; border:1px solid #333; border-radius:12px; '
f'padding:1.5rem; text-align:center;">'
f'<div style="font-size:0.7rem; text-transform:uppercase; '
f'letter-spacing:0.1em; color:#666;">{title}</div>'
f'<div style="font-size:1.7rem; font-weight:300; color:#fff; '
f'margin-top:0.25rem;">{main}</div>'
f'{sub_html}</div>'
)
def _wrap(msg: str) -> str:
return f'<div style="text-align:center; color:#666; padding:2rem 0;">{msg}</div>'
# ── Custom CSS for dark monochrome look ───────────────────────────────────────
# Injected via gr.Blocks(css=...). Pulls the Inter font from Google Fonts,
# narrows and centers the page, hides Gradio's default footer, and lightens
# the h1 weight to match the theme below.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.gradio-container { max-width: 720px !important; margin: auto !important; }
footer { display: none !important; }
h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; }
"""
# ── Theme ─────────────────────────────────────────────────────────────────────
# Monochrome Gradio theme: all three hues pinned to gray, Inter as the UI font.
# Every colour is set for both variants (the plain key and its *_dark twin) so
# the palette is identical regardless of the visitor's colour-scheme setting.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.gray,
    secondary_hue=gr.themes.colors.gray,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Inter"),
).set(
    # Page and block backgrounds: near-black with subtle #222 borders.
    body_background_fill="#0a0a0a",
    body_background_fill_dark="#0a0a0a",
    block_background_fill="#111111",
    block_background_fill_dark="#111111",
    block_border_color="#222222",
    block_border_color_dark="#222222",
    # Text: light gray primary, muted #666 for labels and subdued copy.
    block_label_text_color="#666666",
    block_label_text_color_dark="#666666",
    block_title_text_color="#e0e0e0",
    block_title_text_color_dark="#e0e0e0",
    body_text_color="#e0e0e0",
    body_text_color_dark="#e0e0e0",
    body_text_color_subdued="#666666",
    body_text_color_subdued_dark="#666666",
    # Primary buttons: transparent fill, thin border — outline style.
    button_primary_background_fill="transparent",
    button_primary_background_fill_dark="transparent",
    button_primary_border_color="#222222",
    button_primary_border_color_dark="#222222",
    button_primary_text_color="#e0e0e0",
    button_primary_text_color_dark="#e0e0e0",
    # Inputs match the block surfaces.
    input_background_fill="#111111",
    input_background_fill_dark="#111111",
    input_border_color="#222222",
    input_border_color_dark="#222222",
)
# ── App ───────────────────────────────────────────────────────────────────────
# Layout: title, two input tabs (mic / file), a shared HTML results pane,
# and a small footer note. Both Analyze buttons feed the same handler.
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk Β· Cry Classifier") as app:
    gr.Markdown("# πŸ‘Ά TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.")

    with gr.Tabs():
        with gr.TabItem("πŸŽ™ Record"):
            recorder = gr.Audio(
                label="Record from mic",
                sources=["microphone"],
                type="numpy",
            )
            record_button = gr.Button("Analyze recording", variant="primary", size="lg")
        with gr.TabItem("πŸ“ Upload file"):
            uploader = gr.Audio(
                label="Upload WAV / MP3 / FLAC",
                sources=["upload"],
                type="numpy",
            )
            upload_button = gr.Button("Analyze file", variant="primary", size="lg")

    # Results pane shared by both tabs; starts with a placeholder prompt.
    results_pane = gr.HTML(
        label="Results",
        value=_wrap("Upload or record audio above, then click Analyze."),
    )

    record_button.click(fn=analyze, inputs=recorder, outputs=results_pane)
    upload_button.click(fn=analyze, inputs=uploader, outputs=results_pane)

    gr.Markdown(
        '<p style="text-align:center; font-size:0.75rem; color:#444; margin-top:2rem;">'
        "TotTalk Cry Eval Β· Open-source multi-model comparison tool Β· "
        "Models run server-side β€” your audio is not stored.</p>"
    )

if __name__ == "__main__":
    app.launch()