File size: 8,544 Bytes
66c65bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c68f51
66c65bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c68f51
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""TotTalk Cry Eval β€” Gradio web UI."""

from __future__ import annotations

from collections import Counter

import gradio as gr
import librosa
import numpy as np

from audio.preprocess import SAMPLE_RATE, is_silent, normalize_audio, resample
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import EnsembleClassifier, compute_consensus

# ── Load models at startup (cached in process) ───────────────────────────────
# Built once at import time so every request served by this process reuses
# the same classifier instance.
ensemble = EnsembleClassifier(use_yamnet_gate=True)
# NOTE(review): presumably loads every member model up front so the first
# request does not pay the load cost — confirm in models.ensemble.
ensemble.load_all()


# ── Core analysis function ────────────────────────────────────────────────────
def _as_float_mono(data: np.ndarray) -> np.ndarray:
    """Scale non-float32 samples into float32 and average channels to mono."""
    if data.dtype != np.float32:
        # Cast *before* taking the magnitude: np.abs on an int16 array wraps
        # -32768 back to -32768, which would understate the peak.
        data = data.astype(np.float32)
        peak = float(np.abs(data).max())
        # Guard with 1.0 so all-zero input is not divided by zero.
        data = data / max(peak, 1.0)

    if data.ndim > 1:
        data = data.mean(axis=1)
    return data


def _loudest_window(data: np.ndarray, window_len: int) -> np.ndarray | None:
    """Return the highest-RMS window of ``window_len`` samples, or ``None``.

    Windows are taken at half-window hops, plus one extra window aligned to
    the end of the signal so trailing audio is never ignored. Signals shorter
    than one window are zero-padded. ``None`` means every window had zero
    (or non-finite) energy.
    """
    if len(data) < window_len:
        data = np.pad(data, (0, window_len - len(data)))

    hop = window_len // 2
    last_start = len(data) - window_len
    starts = list(range(0, last_start + 1, hop))
    if starts[-1] != last_start:
        starts.append(last_start)

    best_window = None
    best_rms = 0.0
    for start in starts:
        chunk = data[start : start + window_len]
        rms = float(np.sqrt(np.mean(chunk**2)))
        if rms > best_rms:
            best_rms = rms
            best_window = chunk
    return best_window


def analyze(audio_tuple: tuple[int, np.ndarray] | None) -> str:
    """Accept audio from Gradio, run the ensemble, return styled HTML.

    Args:
        audio_tuple: ``(sample_rate, samples)`` as delivered by a Gradio
            ``Audio`` component, or ``None`` when no audio was provided.

    Returns:
        An HTML fragment: a prompt when there is no audio, a
        "No cry detected" card when the audio is empty or silent, otherwise
        the rendered per-model results.
    """
    if audio_tuple is None:
        return _wrap("Upload or record audio to get started.")

    sr, data = audio_tuple
    data = np.asarray(data)

    best_window = None
    # A zero-length recording has no peak to normalise against; treat it the
    # same as silence instead of raising inside ``.max()``.
    if data.size:
        # Gradio gives int16 or float — normalize to float32, then mono.
        data = _as_float_mono(data)

        # Resample to the rate the models expect.
        if sr != SAMPLE_RATE:
            data = resample(data, sr, SAMPLE_RATE)

        # Pick the loudest 1-second window.
        best_window = _loudest_window(data, SAMPLE_RATE)

    if best_window is None or is_silent(best_window):
        return _card("Result", "No cry detected",
                      "The audio seems silent or doesn't contain a baby cry.")

    best_window = normalize_audio(best_window)
    predictions = ensemble.predict_all(best_window, SAMPLE_RATE)

    return _render_results(predictions)


# ── HTML renderers ────────────────────────────────────────────────────────────
def _render_results(predictions: list[CryPrediction]) -> str:
    """Build the full results HTML for one analysis run.

    Args:
        predictions: Per-model predictions returned by the ensemble.

    Returns:
        Newline-joined HTML: an optional consensus card (only when
        ``compute_consensus`` yields text), a "Model breakdown" heading, and
        one card per prediction. Predictions labelled ``no_cry`` with zero
        confidence are omitted from the breakdown.
    """
    parts: list[str] = []

    # Consensus
    consensus_text = compute_consensus(predictions)
    if consensus_text:
        # Majority label among real classifier votes; the detector, failed
        # models and non-class labels are excluded. It is only used to look
        # up the advice line shown under the consensus.
        valid = [p.label for p in predictions
                 if p.model_name != "YAMNet-detector"
                 and not p.error
                 and p.label not in ("no_cry", "timeout", "error")]
        winning = Counter(valid).most_common(1)[0][0] if valid else ""
        advice = LABEL_MEANING.get(winning, "")
        parts.append(_card("Consensus", consensus_text, advice))

    # Model breakdown
    parts.append('<div style="margin-top:1.25rem; font-size:0.7rem; '
                 'text-transform:uppercase; letter-spacing:0.08em; '
                 'color:#666; font-weight:500;">Model breakdown</div>')

    for pred in predictions:
        if pred.label == "no_cry" and pred.confidence == 0.0:
            continue
        emoji = LABEL_EMOJI.get(pred.label, "")
        label = pred.label.replace("_", " ").title()
        # Round rather than truncate (0.57 * 100 == 56.999…), and clamp so an
        # out-of-range confidence cannot produce an invalid bar width.
        pct = min(100, max(0, round(pred.confidence * 100)))
        parts.append(
            f'<div style="background:#111; border:1px solid #222; '
            f'border-radius:12px; padding:1.1rem 1.4rem; margin-top:0.6rem;">'
            f'<div style="font-size:0.7rem; text-transform:uppercase; '
            f'letter-spacing:0.08em; color:#666; font-weight:500;">{pred.model_name}</div>'
            f'<div style="font-size:1.3rem; font-weight:600; color:#fff; '
            f'margin-top:0.15rem;">{emoji} {label}</div>'
            f'<div style="font-size:0.8rem; color:#666; margin-top:0.1rem;">'
            f'{pct}% confidence · {pred.latency_ms:.0f} ms</div>'
            f'<div style="background:#1a1a1a; border-radius:4px; height:6px; '
            f'margin-top:0.4rem;">'
            f'<div style="background:#fff; border-radius:4px; height:6px; '
            f'width:{pct}%;"></div></div></div>'
        )

    return "\n".join(parts)


def _card(title: str, main: str, sub: str = "") -> str:
    """Render a centered highlight card as an HTML string.

    ``title`` becomes a small uppercase caption, ``main`` the large headline,
    and ``sub`` (only when non-empty) an italic footnote. Values are inserted
    as given, without HTML escaping.
    """
    fragments = [
        '<div style="background:#111; border:1px solid #333; border-radius:12px; '
        'padding:1.5rem; text-align:center;">',
        '<div style="font-size:0.7rem; text-transform:uppercase; '
        f'letter-spacing:0.1em; color:#666;">{title}</div>',
        '<div style="font-size:1.7rem; font-weight:300; color:#fff; '
        f'margin-top:0.25rem;">{main}</div>',
    ]
    if sub:
        fragments.append(
            '<div style="font-size:0.85rem; color:#666; margin-top:0.5rem; '
            f'font-style:italic;">{sub}</div>'
        )
    fragments.append("</div>")
    return "".join(fragments)


def _wrap(msg: str) -> str:
    """Return *msg* inside a centered, muted placeholder ``<div>``."""
    opening = '<div style="text-align:center; color:#666; padding:2rem 0;">'
    closing = "</div>"
    return "".join((opening, f"{msg}", closing))


# ── Custom CSS for dark monochrome look ───────────────────────────────────────
# Extra stylesheet handed to gr.Blocks(css=...). One rule per entry; the empty
# first and last entries reproduce the leading and trailing newline.
CUSTOM_CSS = "\n".join(
    (
        "",
        # Typeface used by the rules below.
        "@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');",
        "body, .gradio-container { font-family: 'Inter', sans-serif !important; }",
        # Narrow, centered column.
        ".gradio-container { max-width: 720px !important; margin: auto !important; }",
        # Hide the page footer.
        "footer { display: none !important; }",
        "h1 { font-weight: 300 !important; letter-spacing: -0.03em !important; }",
        "",
    )
)

# ── Theme ─────────────────────────────────────────────────────────────────────
# Monochrome theme. Every colour is applied to both the light variable and
# its "_dark" twin, so the UI looks the same in either browser mode.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.gray,
    secondary_hue=gr.themes.colors.gray,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Inter"),
).set(
    **{
        f"{variable}{suffix}": colour
        for variable, colour in (
            # Page and blocks
            ("body_background_fill", "#0a0a0a"),
            ("block_background_fill", "#111111"),
            ("block_border_color", "#222222"),
            ("block_label_text_color", "#666666"),
            ("block_title_text_color", "#e0e0e0"),
            # Text
            ("body_text_color", "#e0e0e0"),
            ("body_text_color_subdued", "#666666"),
            # Primary buttons
            ("button_primary_background_fill", "transparent"),
            ("button_primary_border_color", "#222222"),
            ("button_primary_text_color", "#e0e0e0"),
            # Inputs
            ("input_background_fill", "#111111"),
            ("input_border_color", "#222222"),
        )
        for suffix in ("", "_dark")
    }
)

# ── App ───────────────────────────────────────────────────────────────────────
# Page layout: heading, two input tabs (microphone / file upload) that share
# one results panel, and a footer note. User-facing strings below were
# restored from mis-decoded UTF-8 (middle dots, em dash, emoji).
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="TotTalk · Cry Classifier") as app:

    gr.Markdown("# 👶 TotTalk\nUpload or record a baby cry and get an instant multi-model analysis.")

    with gr.Tabs():
        with gr.TabItem("🎙 Record"):
            mic_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Record from mic",
            )
            mic_btn = gr.Button("Analyze recording", variant="primary", size="lg")
        with gr.TabItem("📁 Upload file"):
            file_input = gr.Audio(
                sources=["upload"],
                type="numpy",
                label="Upload WAV / MP3 / FLAC",
            )
            file_btn = gr.Button("Analyze file", variant="primary", size="lg")

    # Shared results area; starts with a placeholder prompt.
    output = gr.HTML(
        value=_wrap("Upload or record audio above, then click Analyze."),
        label="Results",
    )

    # Both buttons run the same analysis and write into the same panel.
    mic_btn.click(fn=analyze, inputs=mic_input, outputs=output)
    file_btn.click(fn=analyze, inputs=file_input, outputs=output)

    gr.Markdown(
        '<p style="text-align:center; font-size:0.75rem; color:#444; margin-top:2rem;">'
        "TotTalk Cry Eval · Open-source multi-model comparison tool · "
        "Models run server-side — your audio is not stored.</p>"
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()