""" Speech Fluency Analysis - Hugging Face Gradio App WavLM stutter detection + Whisper transcription. """ import os import torch import torch.nn as nn import torchaudio import numpy as np import gradio as gr from datetime import datetime from transformers import WavLMModel STUTTER_LABELS = ["Prolongation", "Block", "SoundRep", "WordRep", "Interjection"] STUTTER_INFO = { "Prolongation": "Sound stretched longer than normal (e.g. 'Ssssnake')", "Block": "Complete stoppage of airflow/sound with tension", "SoundRep": "Sound/syllable repetition (e.g. 'B-b-b-ball')", "WordRep": "Whole word repetition (e.g. 'I-I-I want')", "Interjection": "Filler words like 'um', 'uh', 'like'", } DEVICE = "cuda" if torch.cuda.is_available() else "cpu" class WaveLmStutterClassification(nn.Module): def __init__(self, num_labels=5): super().__init__() self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base") self.hidden_size = self.wavlm.config.hidden_size for p in self.wavlm.parameters(): p.requires_grad = False self.classifier = nn.Linear(self.hidden_size, num_labels) def forward(self, x, attention_mask=None): h = self.wavlm(x, attention_mask=attention_mask).last_hidden_state return self.classifier(h.mean(dim=1)) wavlm_model = None whisper_model = None models_loaded = False def load_models(): """Load WavLM checkpoint and Whisper once.""" global wavlm_model, whisper_model, models_loaded if models_loaded: return True print("Loading WavLM ...") wavlm_model = WaveLmStutterClassification(num_labels=5) ckpt = "wavlm_stutter_classification_best.pth" if os.path.exists(ckpt): state = torch.load(ckpt, map_location=DEVICE, weights_only=False) if isinstance(state, dict) and "model_state_dict" in state: wavlm_model.load_state_dict(state["model_state_dict"]) else: wavlm_model.load_state_dict(state) wavlm_model.to(DEVICE).eval() print("Loading Whisper ...") import whisper whisper_model = whisper.load_model("base", device=DEVICE) models_loaded = True print("Models ready.") return True # FFmpeg explained: # torchaudio.load() uses FFmpeg under the hood as a system-level library to # DECODE compressed audio formats (mp3, m4a, ogg, flac) into raw PCM samples. # FFmpeg is a CLI/OS tool - torchaudio calls it via its C backend. # The decoded PCM data is then wrapped into a torch.Tensor (the waveform). # # Pipeline: audio file -> FFmpeg decodes -> raw samples -> torch.Tensor # # packages.txt lists "ffmpeg" so HF Spaces installs it at OS level. def load_audio(path): """Load any audio file to 16 kHz mono tensor via torchaudio (uses FFmpeg).""" waveform, sr = torchaudio.load(path) if waveform.size(0) > 1: waveform = waveform.mean(dim=0, keepdim=True) if sr != 16000: waveform = torchaudio.transforms.Resample(sr, 16000)(waveform) return waveform.squeeze(0), 16000 def analyze_chunk(chunk, threshold=0.5): """Run WavLM on a single chunk.""" with torch.no_grad(): logits = wavlm_model(chunk.unsqueeze(0).to(DEVICE)) probs = torch.sigmoid(logits).cpu().numpy()[0] detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold] prob_dict = dict(zip(STUTTER_LABELS, [round(float(p), 3) for p in probs])) return detected, prob_dict def analyze_audio(audio_path, threshold, progress=gr.Progress()): """Main pipeline: chunk -> WavLM -> Whisper -> formatted results.""" if audio_path is None: return "Upload an audio file first.", "", "", "" if isinstance(audio_path, tuple): import tempfile, soundfile as sf sr, data = audio_path tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) sf.write(tmp.name, data, sr) audio_path = tmp.name progress(0.05, desc="Loading models ...") if not models_loaded and not load_models(): return "Failed to load models.", "", "", "" progress(0.15, desc="Loading audio ...") waveform, sr = load_audio(audio_path) duration = len(waveform) / sr progress(0.25, desc="Detecting stutters ...") chunk_samples = 3 * sr counts = {l: 0 for l in STUTTER_LABELS} timeline_rows = [] total_chunks = max(1, (len(waveform) + chunk_samples - 1) // chunk_samples) for i, start in enumerate(range(0, len(waveform), chunk_samples)): progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...") end = min(start + chunk_samples, len(waveform)) chunk = waveform[start:end] if len(chunk) < chunk_samples: chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk))) detected, probs = analyze_chunk(chunk, threshold) for label in detected: counts[label] += 1 time_str = f"{start/sr:.1f}-{end/sr:.1f}s" timeline_rows.append({"time": time_str, "detected": detected or ["Fluent"], "probs": probs}) progress(0.75, desc="Transcribing ...") transcription = whisper_model.transcribe(audio_path).get("text", "").strip() progress(0.90, desc="Building report ...") total_stutters = sum(counts.values()) chunks_with_stutter = sum(1 for r in timeline_rows if "Fluent" not in r["detected"]) stutter_pct = (chunks_with_stutter / total_chunks) * 100 if total_chunks else 0 word_count = len(transcription.split()) if transcription else 0 wpm = (word_count / duration) * 60 if duration > 0 else 0 severity = ( "Very Mild" if stutter_pct < 5 else "Mild" if stutter_pct < 10 else "Moderate" if stutter_pct < 20 else "Severe" if stutter_pct < 30 else "Very Severe" ) summary_lines = [ "## Analysis Results\n", "| Metric | Value |", "|--------|-------|", f"| Duration | {duration:.1f}s |", f"| Words | {word_count} |", f"| Speaking Rate | {wpm:.0f} wpm |", f"| Stutter Events | {total_stutters} |", f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |", f"| Severity | **{severity}** |", "", "### Stutter Counts", "", ] for label in STUTTER_LABELS: c = counts[label] bar = "X" * min(c, 20) icon = "!" if c > 0 else "o" summary_lines.append(f"- {icon} **{label}**: {c} {bar}") summary_md = "\n".join(summary_lines) tl_lines = ["| Time | Detected |", "|------|----------|"] for row in timeline_rows: tl_lines.append(f"| {row['time']} | {', '.join(row['detected'])} |") timeline_md = "\n".join(tl_lines) recs = ["## Recommendations\n"] if severity in ("Very Mild", "Mild"): recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.") elif severity == "Moderate": recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.") else: recs.append("- Professional speech-language pathology evaluation is strongly recommended.") dominant = max(counts, key=counts.get) if counts[dominant] > 0: recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}") if wpm > 180: recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.") recs.append("\n### Stutter Type Definitions\n") for label, desc in STUTTER_INFO.items(): recs.append(f"- **{label}**: {desc}") recs_md = "\n".join(recs) progress(1.0, desc="Done!") return summary_md, transcription, timeline_md, recs_md CUSTOM_CSS = """ .gradio-container { max-width: 960px !important; } .gr-button-primary { background: #0f766e !important; } """ with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo: gr.Markdown( """ # Speech Fluency Analysis Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection) and **Whisper** (transcription). Supported formats: **WAV, MP3, M4A, FLAC, OGG** """ ) with gr.Row(): with gr.Column(scale=1): audio_in = gr.Audio(label="Upload Audio", type="filepath") threshold = gr.Slider( 0.3, 0.7, value=0.5, step=0.05, label="Detection Threshold", info="Lower = more sensitive, Higher = more strict", ) btn = gr.Button("Analyze", variant="primary", size="lg") with gr.Column(scale=2): summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*") with gr.Tabs(): with gr.TabItem("Transcription"): trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False) with gr.TabItem("Timeline"): timeline_out = gr.Markdown() with gr.TabItem("Recommendations"): recs_out = gr.Markdown() gr.Markdown( "---\n*Disclaimer: AI-assisted analysis for clinical support only. " "Consult a qualified Speech-Language Pathologist for diagnosis.*" ) btn.click( fn=analyze_audio, inputs=[audio_in, threshold], outputs=[summary_out, trans_out, timeline_out, recs_out], show_progress="full", ) print("Loading models at startup ...") load_models() print("Launching Gradio ...") demo.queue() demo.launch(ssr_mode=False)