File size: 9,521 Bytes
ba86bda
 
 
 
 
 
fb9af37
ba86bda
 
e792433
ba86bda
fb9af37
 
 
ba86bda
 
 
 
 
 
 
 
 
 
 
fb9af37
0f70449
fb9af37
e792433
fb9af37
 
 
ba86bda
 
431e771
fb9af37
ba86bda
 
 
 
e792433
fb9af37
 
e792433
fb9af37
ba86bda
fb9af37
ba86bda
e792433
 
 
ba86bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb9af37
ba86bda
fb9af37
 
ba86bda
 
 
 
 
 
 
 
 
 
e792433
ba86bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb9af37
0b7e787
ba86bda
0b7e787
ba86bda
0b7e787
ba86bda
0b7e787
ba86bda
 
0b7e787
ba86bda
 
fb9af37
ba86bda
 
 
 
 
 
 
 
 
 
 
 
0b7e787
ba86bda
 
 
 
0b7e787
fb9af37
ba86bda
fb9af37
 
ba86bda
e792433
2b1a086
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""
Speech Fluency Analysis - Hugging Face Gradio App
WavLM stutter detection + Whisper transcription.
"""

import os
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import gradio as gr
from datetime import datetime
from transformers import WavLMModel

# Classifier output order: logit/probability index i belongs to STUTTER_LABELS[i].
STUTTER_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
]

# Short human-readable definition for each label, used in the recommendations report.
STUTTER_INFO = dict(
    Prolongation="Sound stretched longer than normal (e.g. 'Ssssnake')",
    Block="Complete stoppage of airflow/sound with tension",
    SoundRep="Sound/syllable repetition (e.g. 'B-b-b-ball')",
    WordRep="Whole word repetition (e.g. 'I-I-I want')",
    Interjection="Filler words like 'um', 'uh', 'like'",
)

# Run inference on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


class WaveLmStutterClassification(nn.Module):
    """Multi-label stutter classifier: frozen WavLM-base encoder + linear head.

    The pretrained ``microsoft/wavlm-base`` encoder is loaded and every one of
    its parameters is frozen, so only ``classifier`` carries trainable weights.
    The attribute names ``wavlm`` and ``classifier`` form the ``state_dict``
    keys, so saved checkpoints depend on them.
    """

    def __init__(self, num_labels=5):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size
        # Freeze the whole encoder in one call; only the head is trained.
        self.wavlm.requires_grad_(False)
        self.classifier = nn.Linear(self.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return raw (pre-sigmoid) logits, one per label.

        ``x`` is presumably a batch of raw waveforms, (batch, samples). The
        caller in this file passes ``chunk.unsqueeze(0)``. TODO confirm for
        other callers.

        NOTE(review): ``attention_mask`` is forwarded to WavLM, but the mean
        over dim 1 below is unmasked, so padded frames still contribute.
        """
        encoded = self.wavlm(x, attention_mask=attention_mask)
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(pooled)


# Model singletons: populated once by load_models() and read by the analysis code.
wavlm_model = None
whisper_model = None
models_loaded = False


def load_models(ckpt="wavlm_stutter_classification_best.pth"):
    """Load the WavLM stutter classifier and the Whisper "base" model once.

    After a successful call, ``models_loaded`` is True and later calls return
    immediately.

    Args:
        ckpt: Path to the fine-tuned classifier weights. Either a bare
            ``state_dict`` or a dict holding it under ``"model_state_dict"``
            is accepted. If the file does not exist, a warning is printed and
            the classifier head keeps its random initialisation.

    Returns:
        True once both models are loaded.
    """
    global wavlm_model, whisper_model, models_loaded
    if models_loaded:
        return True

    print("Loading WavLM ...")
    classifier = WaveLmStutterClassification(num_labels=len(STUTTER_LABELS))
    if os.path.exists(ckpt):
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code embedded in the file. Only
        # load checkpoints you trust; prefer weights_only=True if the file
        # holds nothing but tensors and plain containers.
        state = torch.load(ckpt, map_location=DEVICE, weights_only=False)
        # Accept both a bare state_dict and a training checkpoint wrapping it.
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]
        classifier.load_state_dict(state)
    else:
        # Without the checkpoint the linear head is randomly initialised, so
        # its outputs carry no information; say so instead of failing silently.
        print(
            f"WARNING: checkpoint '{ckpt}' not found; the stutter classifier "
            "head is untrained and its predictions are not meaningful."
        )
    classifier.to(DEVICE).eval()

    print("Loading Whisper ...")
    import whisper
    asr = whisper.load_model("base", device=DEVICE)

    # Publish the globals only after both models loaded, so an exception
    # part-way through never leaves a half-initialised state behind.
    wavlm_model = classifier
    whisper_model = asr
    models_loaded = True
    print("Models ready.")
    return True


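# Audio decoding and FFmpeg:
# torchaudio.load() hands decoding to one of its I/O backends; which one is
# chosen depends on the torchaudio version and what is installed. When the
# FFmpeg backend is used, torchaudio links against FFmpeg's shared libraries
# (libavformat/libavcodec) to DECODE compressed formats (mp3, m4a, ogg, flac)
# into raw PCM samples. It does not shell out to the `ffmpeg` command-line tool.
# The decoded samples are returned as a torch.Tensor (the waveform).
#
# Pipeline: audio file -> FFmpeg decodes -> raw samples -> torch.Tensor
#
# packages.txt lists "ffmpeg" so HF Spaces installs it at the OS level. Whisper's
# transcribe() also needs it, because Whisper runs the `ffmpeg` binary to read files.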

def load_audio(path):
    """Decode an audio file into a mono 16 kHz waveform via torchaudio.

    Multi-channel audio is averaged down to one channel, and any other sample
    rate is resampled to 16 kHz.

    Returns:
        (samples, 16000), where ``samples`` is a 1-D tensor.
    """
    target_rate = 16000
    signal, native_rate = torchaudio.load(path)
    # torchaudio yields (channels, frames); collapse extra channels by averaging.
    signal = signal.mean(dim=0, keepdim=True) if signal.size(0) > 1 else signal
    if native_rate != target_rate:
        resampler = torchaudio.transforms.Resample(native_rate, target_rate)
        signal = resampler(signal)
    return signal.squeeze(0), target_rate


def analyze_chunk(chunk, threshold=0.5):
    """Score one audio chunk with the WavLM stutter model.

    Args:
        chunk: 1-D waveform tensor; a batch dimension of 1 is added here.
        threshold: A label is reported when its sigmoid probability is
            strictly greater than this value.

    Returns:
        (detected, prob_dict): the list of labels above ``threshold`` and a
        mapping from every label to its probability rounded to 3 decimals.
    """
    batch = chunk.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Independent sigmoid per label: this is multi-label, not softmax.
        scores = torch.sigmoid(wavlm_model(batch)).cpu().numpy()[0]

    detected = []
    for idx, score in enumerate(scores):
        if score > threshold:
            detected.append(STUTTER_LABELS[idx])
    prob_dict = dict(zip(STUTTER_LABELS, (round(float(s), 3) for s in scores)))
    return detected, prob_dict


def _detect_stutters(waveform, sr, threshold, progress, chunk_seconds=3):
    """Classify consecutive fixed-length chunks of *waveform* with WavLM.

    Args:
        waveform: 1-D waveform tensor sampled at *sr*.
        sr: Sample rate of *waveform* in Hz.
        threshold: Per-label probability cutoff forwarded to analyze_chunk.
        progress: Gradio progress callback, advanced within [0.25, 0.70).
        chunk_seconds: Chunk length in seconds.

    Returns:
        (counts, timeline_rows, total_chunks). ``counts`` maps each label to
        the number of chunks it was detected in. ``timeline_rows`` holds one
        dict per chunk with "time", "detected" and "probs" keys.
    """
    n_samples = len(waveform)
    chunk_samples = chunk_seconds * sr
    counts = {label: 0 for label in STUTTER_LABELS}
    timeline_rows = []
    total_chunks = max(1, (n_samples + chunk_samples - 1) // chunk_samples)

    for i, start in enumerate(range(0, n_samples, chunk_samples)):
        progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...")
        end = min(start + chunk_samples, n_samples)
        chunk = waveform[start:end]
        if len(chunk) < chunk_samples:
            # Zero-pad the trailing chunk so every model input has equal length.
            chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

        detected, probs = analyze_chunk(chunk, threshold)
        for label in detected:
            counts[label] += 1

        time_str = f"{start/sr:.1f}-{end/sr:.1f}s"
        timeline_rows.append({"time": time_str, "detected": detected or ["Fluent"], "probs": probs})

    return counts, timeline_rows, total_chunks


def _severity_label(stutter_pct):
    """Map the percentage of stutter-affected chunks to a severity label."""
    if stutter_pct < 5:
        return "Very Mild"
    if stutter_pct < 10:
        return "Mild"
    if stutter_pct < 20:
        return "Moderate"
    if stutter_pct < 30:
        return "Severe"
    return "Very Severe"


def _format_summary(duration, word_count, wpm, counts, chunks_with_stutter,
                    total_chunks, stutter_pct, severity):
    """Render the metrics table and per-label stutter counts as Markdown."""
    total_stutters = sum(counts.values())
    lines = [
        "## Analysis Results\n",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Duration | {duration:.1f}s |",
        f"| Words | {word_count} |",
        f"| Speaking Rate | {wpm:.0f} wpm |",
        f"| Stutter Events | {total_stutters} |",
        f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |",
        f"| Severity | **{severity}** |",
        "",
        "### Stutter Counts",
        "",
    ]
    for label in STUTTER_LABELS:
        c = counts[label]
        bar = "X" * min(c, 20)  # text bar, capped at 20 marks
        icon = "!" if c > 0 else "o"
        lines.append(f"- {icon} **{label}**: {c}  {bar}")
    return "\n".join(lines)


def _format_timeline(timeline_rows):
    """Render the per-chunk detections as a two-column Markdown table."""
    lines = ["| Time | Detected |", "|------|----------|"]
    lines.extend(f"| {row['time']} | {', '.join(row['detected'])} |" for row in timeline_rows)
    return "\n".join(lines)


def _format_recommendations(severity, counts, wpm):
    """Render severity-based advice plus stutter-type definitions as Markdown."""
    recs = ["## Recommendations\n"]
    if severity in ("Very Mild", "Mild"):
        recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.")
    elif severity == "Moderate":
        recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.")
    else:
        recs.append("- Professional speech-language pathology evaluation is strongly recommended.")

    dominant = max(counts, key=counts.get)
    if counts[dominant] > 0:
        recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}")

    if wpm > 180:
        recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.")

    recs.append("\n### Stutter Type Definitions\n")
    for label, desc in STUTTER_INFO.items():
        recs.append(f"- **{label}**: {desc}")
    return "\n".join(recs)


def _analyze_file(audio_path, threshold, progress):
    """Run the full pipeline on an audio file path; return the 4 report strings."""
    progress(0.05, desc="Loading models ...")
    if not models_loaded and not load_models():
        return "Failed to load models.", "", "", ""

    progress(0.15, desc="Loading audio ...")
    try:
        waveform, sr = load_audio(audio_path)
    except Exception as err:  # UI boundary: report a readable message instead of crashing
        print(f"Audio decoding failed for {audio_path!r}: {err!r}")
        return f"Could not read the audio file: {err}", "", "", ""
    if waveform.numel() == 0:
        return "The uploaded audio contains no samples.", "", "", ""
    duration = len(waveform) / sr

    progress(0.25, desc="Detecting stutters ...")
    counts, timeline_rows, total_chunks = _detect_stutters(waveform, sr, threshold, progress)

    progress(0.75, desc="Transcribing ...")
    transcription = whisper_model.transcribe(audio_path).get("text", "").strip()

    progress(0.90, desc="Building report ...")
    chunks_with_stutter = sum(1 for r in timeline_rows if "Fluent" not in r["detected"])
    stutter_pct = (chunks_with_stutter / total_chunks) * 100
    word_count = len(transcription.split())
    wpm = (word_count / duration) * 60  # duration > 0: empty audio returned above
    severity = _severity_label(stutter_pct)

    summary_md = _format_summary(
        duration, word_count, wpm, counts, chunks_with_stutter,
        total_chunks, stutter_pct, severity,
    )
    timeline_md = _format_timeline(timeline_rows)
    recs_md = _format_recommendations(severity, counts, wpm)

    progress(1.0, desc="Done!")
    return summary_md, transcription, timeline_md, recs_md


def analyze_audio(audio_path, threshold, progress=gr.Progress()):
    """Gradio handler: chunk -> WavLM -> Whisper -> formatted Markdown report.

    Args:
        audio_path: Path to an audio file, or a ``(sample_rate, array)`` tuple.
            A tuple is written to a temporary WAV file, which is deleted
            before returning.
        threshold: Per-label probability cutoff for stutter detection.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        (summary_md, transcription, timeline_md, recs_md). On failure, the
        first element carries the error message and the rest are empty.
    """
    if audio_path is None:
        return "Upload an audio file first.", "", "", ""

    tmp_path = None
    try:
        if isinstance(audio_path, tuple):
            import tempfile
            import soundfile as sf
            sample_rate, data = audio_path
            # Close the handle before writing so soundfile can reopen the path
            # (required on Windows); the file is removed in `finally`.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
            sf.write(tmp_path, data, sample_rate)
            audio_path = tmp_path
        return _analyze_file(audio_path, threshold, progress)
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)


# Extra page CSS: cap the app width and recolour primary buttons.
CUSTOM_CSS = """
.gradio-container { max-width: 960px !important; }
.gr-button-primary { background: #0f766e !important; }
"""

# UI layout. Gradio places components in the order they are created inside the
# Blocks/Row/Column/Tabs context managers, so statement order here is the
# on-screen order.
with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:

    gr.Markdown(
        """
        # Speech Fluency Analysis
        Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection)
        and **Whisper** (transcription).

        Supported formats: **WAV, MP3, M4A, FLAC, OGG**
        """
    )

    with gr.Row():
        # Input column (scale=1): audio upload, threshold slider, run button.
        with gr.Column(scale=1):
            # type="filepath" makes Gradio pass analyze_audio a path string.
            audio_in = gr.Audio(label="Upload Audio", type="filepath")
            # Module-level Slider component. Its value becomes analyze_audio's
            # `threshold` argument (the per-label probability cutoff).
            threshold = gr.Slider(
                0.3, 0.7, value=0.5, step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more strict",
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")

        # Result column (scale=2 relative to the input column): summary Markdown.
        with gr.Column(scale=2):
            summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*")

    # Detailed outputs, one tab each.
    with gr.Tabs():
        with gr.TabItem("Transcription"):
            trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False)
        with gr.TabItem("Timeline"):
            timeline_out = gr.Markdown()
        with gr.TabItem("Recommendations"):
            recs_out = gr.Markdown()

    gr.Markdown(
        "---\n*Disclaimer: AI-assisted analysis for clinical support only. "
        "Consult a qualified Speech-Language Pathologist for diagnosis.*"
    )

    # The four outputs map positionally to analyze_audio's return tuple:
    # (summary_md, transcription, timeline_md, recs_md).
    btn.click(
        fn=analyze_audio,
        inputs=[audio_in, threshold],
        outputs=[summary_out, trans_out, timeline_out, recs_out],
        show_progress="full",
    )

# Load both models eagerly so the first user request does not pay the load cost.
print("Loading models at startup ...")
load_models()

print("Launching Gradio ...")
# Blocks.queue() returns the Blocks instance itself, so launch() can be chained.
demo.queue().launch(ssr_mode=False)