# Uploaded by throgletworld — commit ba86bda ("Upload 3 files"), verified.
"""
Speech Fluency Analysis - Hugging Face Gradio App
WavLM stutter detection + Whisper transcription.
"""
import os
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import gradio as gr
from datetime import datetime
from transformers import WavLMModel
# Plain-language description of every stutter class, keyed by class name.
# The key order also defines STUTTER_LABELS below, i.e. which classifier
# output index maps to which class, so it must match the checkpoint's order.
STUTTER_INFO = {
    "Prolongation": "Sound stretched longer than normal (e.g. 'Ssssnake')",
    "Block": "Complete stoppage of airflow/sound with tension",
    "SoundRep": "Sound/syllable repetition (e.g. 'B-b-b-ball')",
    "WordRep": "Whole word repetition (e.g. 'I-I-I want')",
    "Interjection": "Filler words like 'um', 'uh', 'like'",
}

# Class name for each classifier output index (dicts preserve insertion order).
STUTTER_LABELS = list(STUTTER_INFO)

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class WaveLmStutterClassification(nn.Module):
    """Multi-label stutter classifier: frozen WavLM encoder plus a linear head.

    The pretrained ``microsoft/wavlm-base`` encoder has every parameter frozen,
    so only ``classifier`` is trainable. ``forward`` returns raw logits (no
    activation is applied here).
    """

    def __init__(self, num_labels=5):
        super().__init__()
        encoder = WavLMModel.from_pretrained("microsoft/wavlm-base")
        # Freeze the whole encoder; requires_grad_ visits every parameter.
        encoder.requires_grad_(False)
        self.wavlm = encoder
        self.hidden_size = encoder.config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Encode ``x`` with WavLM, average over dim 1 and project to logits.

        ``x`` is passed straight to WavLM (presumably (batch, samples) raw audio
        at 16 kHz -- TODO confirm). ``attention_mask`` is forwarded to the
        encoder, but the mean below does not use it, so padded frames still
        contribute to the pooled vector.
        """
        frames = self.wavlm(x, attention_mask=attention_mask).last_hidden_state
        pooled = frames.mean(dim=1)
        return self.classifier(pooled)
# Process-wide model singletons, populated by load_models().
wavlm_model, whisper_model = None, None
models_loaded = False  # becomes True only once both models are ready
def load_models():
    """Load the WavLM stutter classifier and the Whisper model once per process.

    On success the module globals ``wavlm_model`` and ``whisper_model`` are set
    and ``models_loaded`` becomes True, so later calls return immediately.

    Returns:
        True when both models are ready. False when loading raised: the
        traceback is printed, the globals are left untouched and
        ``models_loaded`` stays False, so the next call retries.
        analyze_audio relies on this False return to report
        "Failed to load models." instead of crashing.
    """
    global wavlm_model, whisper_model, models_loaded
    if models_loaded:
        return True
    try:
        print("Loading WavLM ...")
        stutter_model = WaveLmStutterClassification(num_labels=len(STUTTER_LABELS))
        ckpt = "wavlm_stutter_classification_best.pth"
        if os.path.exists(ckpt):
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load a checkpoint you trust this way, never a user-supplied file.
            # Map to the CPU: the freshly built model is still on the CPU and is
            # moved to DEVICE below, so mapping to DEVICE would only add a copy.
            state = torch.load(ckpt, map_location="cpu", weights_only=False)
            # Accept either a bare state_dict or a training checkpoint wrapping one.
            if isinstance(state, dict) and "model_state_dict" in state:
                state = state["model_state_dict"]
            stutter_model.load_state_dict(state)
        else:
            # Without the checkpoint the linear head keeps its random
            # initialisation, so its predictions carry no information.
            print(
                f"WARNING: checkpoint '{ckpt}' not found; the stutter classifier "
                "head is untrained and its predictions are meaningless."
            )
        stutter_model.to(DEVICE).eval()

        print("Loading Whisper ...")
        import whisper

        speech_model = whisper.load_model("base", device=DEVICE)
    except Exception:  # boundary: callers expect a bool, not an exception
        import traceback

        traceback.print_exc()
        print("Model loading failed.")
        return False

    # Publish the models only after both loaded, so a partial failure never
    # leaves a half-initialised state behind.
    wavlm_model = stutter_model
    whisper_model = speech_model
    models_loaded = True
    print("Models ready.")
    return True
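# How audio decoding works here:
# torchaudio.load() delegates decoding to one of its I/O backends (FFmpeg,
# SoX or soundfile, depending on what is installed). With the FFmpeg backend,
# torchaudio calls FFmpeg's shared libraries (libavformat/libavcodec) from its
# compiled extension -- not the `ffmpeg` command-line program -- to DECODE
# compressed formats (mp3, m4a, ogg, flac) into raw PCM samples, which are
# returned as a torch.Tensor (the waveform).
#
# Pipeline: audio file -> backend decodes -> raw PCM samples -> torch.Tensor
#
# Whisper's transcribe() decodes the file separately by running the `ffmpeg`
# executable in a subprocess, so it must be on PATH as well.
# packages.txt lists "ffmpeg" so HF Spaces installs it at the OS level.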
def load_audio(path):
    """Decode ``path`` and return ``(mono_waveform, 16000)``.

    The file is read with torchaudio.load (channels first). Multi-channel audio
    is averaged down to one channel, and anything not already at 16 kHz is
    resampled, so the returned waveform is a 1-D tensor of 16 kHz samples.
    """
    target_rate = 16000
    signal, native_rate = torchaudio.load(path)
    is_multichannel = signal.size(0) > 1
    mono = signal.mean(dim=0, keepdim=True) if is_multichannel else signal
    if native_rate == target_rate:
        resampled = mono
    else:
        resampled = torchaudio.transforms.Resample(native_rate, target_rate)(mono)
    return resampled.squeeze(0), target_rate
def analyze_chunk(chunk, threshold=0.5):
    """Score one audio chunk with the global WavLM classifier.

    Args:
        chunk: waveform tensor for a single chunk; a leading batch dimension
            of size 1 is added before inference.
        threshold: a label is reported when its sigmoid probability is
            strictly greater than this value.

    Returns:
        ``(detected, prob_dict)``: the labels above ``threshold`` in
        STUTTER_LABELS order, and every label mapped to its probability
        rounded to 3 decimal places.
    """
    batch = chunk.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = torch.sigmoid(wavlm_model(batch))
    scores = scores.cpu().numpy()[0]

    detected = []
    for idx, score in enumerate(scores):
        if score > threshold:
            detected.append(STUTTER_LABELS[idx])
    prob_dict = {
        name: round(float(score), 3) for name, score in zip(STUTTER_LABELS, scores)
    }
    return detected, prob_dict
def _detect_stutters(waveform, sr, threshold, progress):
    """Classify consecutive 3-second chunks of ``waveform`` with analyze_chunk.

    The final partial chunk is zero-padded to full length before inference.

    Returns:
        ``(counts, timeline_rows, total_chunks)``: for each label, the number
        of chunks in which it was detected; one ``{"time", "detected",
        "probs"}`` row per chunk (``detected`` is ``["Fluent"]`` when nothing
        crossed the threshold); and the chunk count (at least 1).
    """
    chunk_samples = 3 * sr
    n_samples = len(waveform)
    counts = {label: 0 for label in STUTTER_LABELS}
    timeline_rows = []
    total_chunks = max(1, (n_samples + chunk_samples - 1) // chunk_samples)
    for i, start in enumerate(range(0, n_samples, chunk_samples)):
        progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...")
        end = min(start + chunk_samples, n_samples)
        chunk = waveform[start:end]
        if len(chunk) < chunk_samples:
            chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
        detected, probs = analyze_chunk(chunk, threshold)
        for label in detected:
            counts[label] += 1
        # The time range uses the unpadded end so the last row shows real audio.
        timeline_rows.append({
            "time": f"{start/sr:.1f}-{end/sr:.1f}s",
            "detected": detected or ["Fluent"],
            "probs": probs,
        })
    return counts, timeline_rows, total_chunks


def _severity_label(stutter_pct):
    """Map the percentage of chunks containing a stutter to a severity band."""
    if stutter_pct < 5:
        return "Very Mild"
    if stutter_pct < 10:
        return "Mild"
    if stutter_pct < 20:
        return "Moderate"
    if stutter_pct < 30:
        return "Severe"
    return "Very Severe"


def _format_summary(duration, word_count, wpm, counts, chunks_with_stutter,
                    total_chunks, stutter_pct, severity):
    """Build the "Analysis Results" Markdown: metrics table plus per-label counts."""
    lines = [
        "## Analysis Results\n",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Duration | {duration:.1f}s |",
        f"| Words | {word_count} |",
        f"| Speaking Rate | {wpm:.0f} wpm |",
        f"| Stutter Events | {sum(counts.values())} |",
        f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |",
        f"| Severity | **{severity}** |",
        "",
        "### Stutter Counts",
        "",
    ]
    for label in STUTTER_LABELS:
        c = counts[label]
        bar = "X" * min(c, 20)  # text bar, capped at 20 characters
        icon = "!" if c > 0 else "o"
        lines.append(f"- {icon} **{label}**: {c} {bar}")
    return "\n".join(lines)


def _format_timeline(timeline_rows):
    """Build the per-chunk Markdown table of detected labels."""
    lines = ["| Time | Detected |", "|------|----------|"]
    lines.extend(
        f"| {row['time']} | {', '.join(row['detected'])} |" for row in timeline_rows
    )
    return "\n".join(lines)


def _format_recommendations(severity, counts, wpm):
    """Build the Recommendations Markdown, including the stutter-type glossary."""
    recs = ["## Recommendations\n"]
    if severity in ("Very Mild", "Mild"):
        recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.")
    elif severity == "Moderate":
        recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.")
    else:
        recs.append("- Professional speech-language pathology evaluation is strongly recommended.")
    dominant = max(counts, key=counts.get)
    if counts[dominant] > 0:
        recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}")
    if wpm > 180:
        recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.")
    recs.append("\n### Stutter Type Definitions\n")
    recs.extend(f"- **{label}**: {desc}" for label, desc in STUTTER_INFO.items())
    return "\n".join(recs)


def _analyze_path(audio_path, threshold, progress):
    """Run the pipeline on an audio file path and return the four UI outputs."""
    progress(0.05, desc="Loading models ...")
    if not models_loaded and not load_models():
        return "Failed to load models.", "", "", ""

    progress(0.15, desc="Loading audio ...")
    try:
        waveform, sr = load_audio(audio_path)
    except Exception as exc:  # UI boundary: report unreadable files instead of a bare Gradio error
        return f"Could not read the audio file: {exc}", "", "", ""
    if len(waveform) == 0:
        return "The audio file contains no samples.", "", "", ""
    duration = len(waveform) / sr  # > 0 from here on

    progress(0.25, desc="Detecting stutters ...")
    counts, timeline_rows, total_chunks = _detect_stutters(waveform, sr, threshold, progress)

    progress(0.75, desc="Transcribing ...")
    transcription = whisper_model.transcribe(audio_path).get("text", "").strip()

    progress(0.90, desc="Building report ...")
    chunks_with_stutter = sum(1 for r in timeline_rows if "Fluent" not in r["detected"])
    stutter_pct = (chunks_with_stutter / total_chunks) * 100
    word_count = len(transcription.split())
    wpm = (word_count / duration) * 60
    severity = _severity_label(stutter_pct)

    summary_md = _format_summary(
        duration=duration,
        word_count=word_count,
        wpm=wpm,
        counts=counts,
        chunks_with_stutter=chunks_with_stutter,
        total_chunks=total_chunks,
        stutter_pct=stutter_pct,
        severity=severity,
    )
    timeline_md = _format_timeline(timeline_rows)
    recs_md = _format_recommendations(severity, counts, wpm)
    progress(1.0, desc="Done!")
    return summary_md, transcription, timeline_md, recs_md


def analyze_audio(audio_path, threshold, progress=gr.Progress()):
    """Gradio handler: chunk -> WavLM -> Whisper -> formatted results.

    Args:
        audio_path: path to an audio file, or a ``(sample_rate, samples)``
            tuple. A tuple is written to a temporary WAV file, which is deleted
            once analysis finishes.
        threshold: per-label probability threshold passed to analyze_chunk.
        progress: Gradio progress tracker (Gradio injects it via this default).

    Returns:
        ``(summary_md, transcription, timeline_md, recommendations_md)``; on
        failure the first element carries the error message and the rest are "".
    """
    if audio_path is None:
        return "Upload an audio file first.", "", "", ""
    temp_path = None
    try:
        if isinstance(audio_path, tuple):
            import tempfile
            import soundfile as sf

            sample_rate, data = audio_path
            # Close the handle before writing by name: reopening a still-open
            # NamedTemporaryFile fails on Windows.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                temp_path = tmp.name
            sf.write(temp_path, data, sample_rate)
            audio_path = temp_path
        return _analyze_path(audio_path, threshold, progress)
    finally:
        # Previously the temporary WAV was never removed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup; never mask the analysis result
# Extra stylesheet passed to gr.Blocks(css=...): caps the page width at 960px
# and sets a background colour for `.gr-button-primary`.
# NOTE(review): `.gr-button-primary` looks like a Gradio 3.x class name; confirm
# it still matches the primary button in the installed Gradio version.
CUSTOM_CSS = """
.gradio-container { max-width: 960px !important; }
.gr-button-primary { background: #0f766e !important; }
"""
# Gradio UI: inputs (audio upload, threshold slider, Analyze button) in one
# column, results (summary plus Transcription / Timeline / Recommendations
# tabs) in a second column. Component creation order defines the layout.
with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Speech Fluency Analysis
Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection)
and **Whisper** (transcription).
Supported formats: **WAV, MP3, M4A, FLAC, OGG**
"""
    )
    with gr.Row():
        # Input column.
        with gr.Column(scale=1):
            # type="filepath": analyze_audio receives a path on disk rather
            # than a (sample_rate, samples) tuple.
            audio_in = gr.Audio(label="Upload Audio", type="filepath")
            # Becomes analyze_audio's `threshold` argument (second input below).
            threshold = gr.Slider(
                0.3, 0.7, value=0.5, step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more strict",
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")
        # Output column; scale=2 gives it twice the relative width of the inputs.
        with gr.Column(scale=2):
            summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*")
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False)
                with gr.TabItem("Timeline"):
                    timeline_out = gr.Markdown()
                with gr.TabItem("Recommendations"):
                    recs_out = gr.Markdown()
    gr.Markdown(
        "---\n*Disclaimer: AI-assisted analysis for clinical support only. "
        "Consult a qualified Speech-Language Pathologist for diagnosis.*"
    )
    # The four outputs receive analyze_audio's 4-tuple positionally:
    # (summary, transcription, timeline, recommendations).
    btn.click(
        fn=analyze_audio,
        inputs=[audio_in, threshold],
        outputs=[summary_out, trans_out, timeline_out, recs_out],
        show_progress="full",
    )
# Warm up both models before serving. The return value is ignored here;
# analyze_audio() calls load_models() again whenever models_loaded is still False.
print("Loading models at startup ...")
load_models()
print("Launching Gradio ...")
# Enable Gradio's request queue, then start the server with server-side
# rendering turned off.
demo.queue()
demo.launch(ssr_mode=False)