| """ |
| Speech Fluency Analysis - Hugging Face Gradio App |
| WavLM stutter detection + Whisper transcription. |
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import torchaudio |
| import numpy as np |
| import gradio as gr |
| from datetime import datetime |
| from transformers import WavLMModel |
|
|
# Stutter classes predicted by the WavLM head. Order matters: analyze_chunk
# maps classifier output index i to STUTTER_LABELS[i].
STUTTER_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
]

# Human-readable description of each stutter class, shown in the report.
STUTTER_INFO = dict(
    Prolongation="Sound stretched longer than normal (e.g. 'Ssssnake')",
    Block="Complete stoppage of airflow/sound with tension",
    SoundRep="Sound/syllable repetition (e.g. 'B-b-b-ball')",
    WordRep="Whole word repetition (e.g. 'I-I-I want')",
    Interjection="Filler words like 'um', 'uh', 'like'",
)

# Device used for both models and for every inference input.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
class WaveLmStutterClassification(nn.Module):
    """Multi-label stutter classifier: frozen WavLM encoder + linear head.

    The pretrained ``microsoft/wavlm-base`` encoder has all parameters frozen;
    only the linear layer applied to the time-averaged hidden states is
    trainable.

    Args:
        num_labels: number of output logits (one per stutter class).
    """

    def __init__(self, num_labels=5):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size
        # Freeze the whole encoder in one call (sets requires_grad=False on every parameter).
        self.wavlm.requires_grad_(False)
        self.classifier = nn.Linear(self.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return raw (pre-sigmoid) logits, one row per input in the batch.

        ``x`` is presumably a batch of raw waveforms, e.g. (batch, samples) —
        TODO confirm against the training code.

        NOTE: the hidden states are averaged over every frame; ``attention_mask``
        is only forwarded to WavLM and does not exclude padded frames from the
        mean.
        """
        encoded = self.wavlm(x, attention_mask=attention_mask)
        pooled = torch.mean(encoded.last_hidden_state, dim=1)
        return self.classifier(pooled)
|
|
|
|
# Lazily populated by load_models(): the stutter classifier (moved to DEVICE
# and put in eval mode) and the Whisper "base" ASR model. Both stay None until
# loading has been attempted.
wavlm_model = whisper_model = None
# Flipped to True by load_models() once both models are ready.
models_loaded = False
|
|
|
|
def load_models():
    """Load the WavLM stutter classifier and the Whisper ASR model once.

    Populates the module globals ``wavlm_model`` and ``whisper_model`` and sets
    ``models_loaded`` on success. Once loading has succeeded, later calls
    return immediately.

    The fine-tuned classifier weights are read from
    ``wavlm_stutter_classification_best.pth`` (relative to the current working
    directory), either as a bare state dict or wrapped under
    ``"model_state_dict"``.

    Returns:
        True when both models are ready; False if loading failed. The error is
        printed, and ``models_loaded`` stays False, so a later call retries.
        ``analyze_audio`` relies on the False return to show a user-facing
        message instead of crashing.
    """
    global wavlm_model, whisper_model, models_loaded
    if models_loaded:
        return True

    try:
        print("Loading WavLM ...")
        model = WaveLmStutterClassification(num_labels=len(STUTTER_LABELS))
        ckpt = "wavlm_stutter_classification_best.pth"
        if os.path.exists(ckpt):
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects; only ever load checkpoints from a trusted source.
            state = torch.load(ckpt, map_location=DEVICE, weights_only=False)
            if isinstance(state, dict) and "model_state_dict" in state:
                state = state["model_state_dict"]
            model.load_state_dict(state)
        else:
            # Without the fine-tuned weights the linear head is randomly
            # initialised and its predictions are meaningless - say so loudly.
            print(
                f"WARNING: checkpoint '{ckpt}' not found; "
                "the stutter classifier head is untrained."
            )
        # Publish the model only once it is fully set up (device + eval mode).
        wavlm_model = model.to(DEVICE).eval()

        print("Loading Whisper ...")
        import whisper

        whisper_model = whisper.load_model("base", device=DEVICE)
    except Exception as exc:
        # Top-level boundary: report the failure and let callers decide.
        import traceback

        traceback.print_exc()
        print(f"Model loading failed: {exc}")
        return False

    models_loaded = True
    print("Models ready.")
    return True
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def load_audio(path):
    """Read an audio file as a mono waveform sampled at 16 kHz.

    Decoding is delegated to ``torchaudio.load`` (which backend it uses,
    e.g. FFmpeg, depends on the installation). Multi-channel audio is averaged
    down to one channel. Any other sample rate is resampled to 16 kHz.

    Returns:
        (waveform, sample_rate): a 1-D tensor of samples and the int 16000.
    """
    target_sr = 16000
    audio, file_sr = torchaudio.load(path)

    n_channels = audio.shape[0]
    if n_channels > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if file_sr != target_sr:
        resampler = torchaudio.transforms.Resample(file_sr, target_sr)
        audio = resampler(audio)

    return audio.squeeze(0), target_sr
|
|
|
|
def analyze_chunk(chunk, threshold=0.5):
    """Classify one audio chunk with the WavLM stutter model.

    Args:
        chunk: 1-D waveform tensor; a batch dimension of 1 is added here.
        threshold: a label is reported when its sigmoid probability is
            strictly greater than this value.

    Returns:
        (detected, prob_dict): the labels above ``threshold`` in
        STUTTER_LABELS order, and a label -> probability map with each
        probability rounded to 3 decimals.
    """
    batch = chunk.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = torch.sigmoid(wavlm_model(batch)).cpu().numpy()[0]

    detected = []
    prob_dict = {}
    for label, score in zip(STUTTER_LABELS, scores):
        prob_dict[label] = round(float(score), 3)
        if score > threshold:
            detected.append(label)
    return detected, prob_dict
|
|
|
|
def _materialize_audio_input(audio_input):
    """Turn a Gradio audio value into a file path.

    Returns ``(path, temp_path)``. A file path is passed through unchanged with
    ``temp_path=None``. A ``(sample_rate, samples)`` tuple (Gradio's numpy audio
    format) is written to a new temporary WAV file. Its path is returned as
    both values, and the caller is responsible for deleting ``temp_path``.
    """
    if not isinstance(audio_input, tuple):
        return audio_input, None

    import tempfile

    import soundfile as sf

    sr, data = audio_input
    # Only reserve a unique name here. Closing the handle immediately avoids
    # leaking the descriptor and lets soundfile reopen the path by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        temp_path = tmp.name
    try:
        sf.write(temp_path, data, sr)
    except Exception:
        os.remove(temp_path)
        raise
    return temp_path, temp_path


def _detect_stutters(waveform, sr, threshold, progress, chunk_seconds=3):
    """Run analyze_chunk() over consecutive fixed-length windows of *waveform*.

    The final partial window is zero-padded to the full window length.

    Returns:
        (counts, timeline_rows, total_chunks):
        - counts: per-label number of windows in which the label was detected.
        - timeline_rows: one ``{"time", "detected", "probs"}`` dict per window;
          ``detected`` is ``["Fluent"]`` when no label fired.
        - total_chunks: the number of windows (at least 1).
    """
    chunk_samples = chunk_seconds * sr
    n_samples = len(waveform)
    counts = {label: 0 for label in STUTTER_LABELS}
    timeline_rows = []
    total_chunks = max(1, (n_samples + chunk_samples - 1) // chunk_samples)

    for i, start in enumerate(range(0, n_samples, chunk_samples)):
        progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...")
        end = min(start + chunk_samples, n_samples)
        chunk = waveform[start:end]
        if len(chunk) < chunk_samples:
            chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

        detected, probs = analyze_chunk(chunk, threshold)
        for label in detected:
            counts[label] += 1

        time_str = f"{start/sr:.1f}-{end/sr:.1f}s"
        timeline_rows.append({"time": time_str, "detected": detected or ["Fluent"], "probs": probs})
    return counts, timeline_rows, total_chunks


def _severity_label(stutter_pct):
    """Map the percentage of stutter-affected chunks to a coarse severity label."""
    if stutter_pct < 5:
        return "Very Mild"
    if stutter_pct < 10:
        return "Mild"
    if stutter_pct < 20:
        return "Moderate"
    if stutter_pct < 30:
        return "Severe"
    return "Very Severe"


def _summary_markdown(duration, word_count, wpm, counts, chunks_with_stutter,
                      total_chunks, stutter_pct, severity):
    """Build the summary markdown: metrics table followed by per-label counts."""
    lines = [
        "## Analysis Results\n",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Duration | {duration:.1f}s |",
        f"| Words | {word_count} |",
        f"| Speaking Rate | {wpm:.0f} wpm |",
        f"| Stutter Events | {sum(counts.values())} |",
        f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |",
        f"| Severity | **{severity}** |",
        "",
        "### Stutter Counts",
        "",
    ]
    for label in STUTTER_LABELS:
        c = counts[label]
        bar = "X" * min(c, 20)  # cap the text bar so long recordings stay readable
        icon = "!" if c > 0 else "o"
        lines.append(f"- {icon} **{label}**: {c} {bar}")
    return "\n".join(lines)


def _timeline_markdown(timeline_rows):
    """Build the per-chunk timeline as a markdown table."""
    lines = ["| Time | Detected |", "|------|----------|"]
    lines.extend(f"| {row['time']} | {', '.join(row['detected'])} |" for row in timeline_rows)
    return "\n".join(lines)


def _recommendations_markdown(severity, counts, wpm):
    """Build the recommendations markdown plus the stutter-type glossary."""
    recs = ["## Recommendations\n"]
    if severity in ("Very Mild", "Mild"):
        recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.")
    elif severity == "Moderate":
        recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.")
    else:
        recs.append("- Professional speech-language pathology evaluation is strongly recommended.")

    dominant = max(counts, key=counts.get)
    if counts[dominant] > 0:
        recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}")

    if wpm > 180:
        recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.")

    recs.append("\n### Stutter Type Definitions\n")
    for label, desc in STUTTER_INFO.items():
        recs.append(f"- **{label}**: {desc}")
    return "\n".join(recs)


def _run_analysis(audio_path, threshold, progress):
    """Run the full pipeline on an audio file path; return the four UI outputs."""
    progress(0.05, desc="Loading models ...")
    if not models_loaded and not load_models():
        return "Failed to load models.", "", "", ""

    progress(0.15, desc="Loading audio ...")
    waveform, sr = load_audio(audio_path)
    if len(waveform) == 0:
        # Nothing to analyse; avoid reporting a misleading "Very Mild" result.
        return "The uploaded audio contains no samples.", "", "", ""
    duration = len(waveform) / sr

    progress(0.25, desc="Detecting stutters ...")
    counts, timeline_rows, total_chunks = _detect_stutters(waveform, sr, threshold, progress)

    progress(0.75, desc="Transcribing ...")
    transcription = whisper_model.transcribe(audio_path).get("text", "").strip()

    progress(0.90, desc="Building report ...")
    chunks_with_stutter = sum(1 for row in timeline_rows if "Fluent" not in row["detected"])
    stutter_pct = (chunks_with_stutter / total_chunks) * 100  # total_chunks >= 1
    word_count = len(transcription.split())
    wpm = (word_count / duration) * 60  # duration > 0 after the empty-audio guard
    severity = _severity_label(stutter_pct)

    summary_md = _summary_markdown(
        duration, word_count, wpm, counts, chunks_with_stutter,
        total_chunks, stutter_pct, severity,
    )
    timeline_md = _timeline_markdown(timeline_rows)
    recs_md = _recommendations_markdown(severity, counts, wpm)

    progress(1.0, desc="Done!")
    return summary_md, transcription, timeline_md, recs_md


def analyze_audio(audio_path, threshold, progress=gr.Progress()):
    """Gradio handler: chunked WavLM stutter detection + Whisper transcription.

    Args:
        audio_path: path to an audio file, or a ``(sample_rate, samples)`` tuple
            as produced by a numpy-typed ``gr.Audio``; ``None`` when nothing
            was uploaded.
        threshold: per-label probability cutoff passed to analyze_chunk().
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Four strings for the UI: summary markdown, transcription text,
        timeline markdown and recommendations markdown. For the handled error
        cases (no upload, model loading failure, empty audio) the first string
        is a user-facing message and the others are empty.
    """
    if audio_path is None:
        return "Upload an audio file first.", "", "", ""

    audio_path, temp_path = _materialize_audio_input(audio_path)
    try:
        return _run_analysis(audio_path, threshold, progress)
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Cleanup is best-effort; never let it mask the analysis result.
                pass
|
|
|
|
# Extra stylesheet passed to gr.Blocks(css=...): narrows the page and recolours
# primary buttons.
# NOTE(review): `.gr-button-primary` is a Gradio 3.x class name; newer Gradio
# releases may not emit it, so confirm the second rule still has an effect.
CUSTOM_CSS = (
    "\n"
    ".gradio-container { max-width: 960px !important; }\n"
    ".gr-button-primary { background: #0f766e !important; }\n"
)
|
|
# Header text; gr.Markdown renders it, so surrounding whitespace is cosmetic.
_INTRO_MD = """
# Speech Fluency Analysis
Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection)
and **Whisper** (transcription).

Supported formats: **WAV, MP3, M4A, FLAC, OGG**
"""

_DISCLAIMER_MD = (
    "---\n*Disclaimer: AI-assisted analysis for clinical support only. "
    "Consult a qualified Speech-Language Pathologist for diagnosis.*"
)

with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: audio input, sensitivity control and the trigger button.
        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Upload Audio", type="filepath")
            threshold = gr.Slider(
                minimum=0.3,
                maximum=0.7,
                value=0.5,
                step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more strict",
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")

        # Right column: headline metrics produced by analyze_audio().
        with gr.Column(scale=2):
            summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*")

    # NOTE(review): the original file lost its indentation, so it is unclear
    # whether these tabs sat inside the right-hand column or full-width below
    # the row. They are placed full-width here; confirm the intended layout.
    with gr.Tabs():
        with gr.TabItem("Transcription"):
            trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False)
        with gr.TabItem("Timeline"):
            timeline_out = gr.Markdown()
        with gr.TabItem("Recommendations"):
            recs_out = gr.Markdown()

    gr.Markdown(_DISCLAIMER_MD)

    # Output order must match the 4-tuple returned by analyze_audio().
    btn.click(
        analyze_audio,
        inputs=[audio_in, threshold],
        outputs=[summary_out, trans_out, timeline_out, recs_out],
        show_progress="full",
    )
|
|
# Load the models eagerly so the first request does not pay the loading cost.
print("Loading models at startup ...")
if not load_models():
    # Do not abort the launch: analyze_audio() calls load_models() again on
    # each request while models_loaded is still False.
    print("WARNING: models failed to load at startup; will retry on the first analysis.")

print("Launching Gradio ...")
# Enable the request queue (needed for gr.Progress updates) before launching.
demo.queue()
demo.launch(ssr_mode=False)
|
|