throgletworld committed on
Commit
e792433
·
verified ·
1 Parent(s): b8adbd4

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +131 -295
  2. requirements.txt +0 -3
app.py CHANGED
@@ -1,48 +1,21 @@
1
- """
2
- Hugging Face Spaces - Gradio App for Stutter Analysis
3
- =====================================================
4
- This is a standalone Gradio app for deployment on Hugging Face Spaces.
5
-
6
- To deploy:
7
- 1. Create a new Space on huggingface.co/spaces
8
- 2. Choose "Gradio" as SDK
9
- 3. Upload this folder's contents
10
- 4. Add your model checkpoint to the Space
11
- """
12
-
13
  import gradio as gr
14
  import torch
15
- import torchaudio
16
- import tempfile
17
  import os
18
- import json
19
- import soundfile as sf
20
- import librosa
21
  from datetime import datetime
22
  from transformers import WavLMModel
23
  import torch.nn as nn
24
- import whisper
25
 
26
- # ============================================================================
27
- # MODEL DEFINITION (same as models/WaveLm_model.py)
28
- # ============================================================================
29
 
30
  class WaveLmStutterClassification(nn.Module):
31
- def __init__(self, num_labels=5, freeze_encoder=True, unfreeze_last_n_layers=1):
32
  super().__init__()
33
  self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
34
  self.hidden_size = self.wavlm.config.hidden_size
35
-
36
- if freeze_encoder:
37
- for param in self.wavlm.parameters():
38
- param.requires_grad = False
39
-
40
- if unfreeze_last_n_layers > 0:
41
- for layer in self.wavlm.encoder.layers[-unfreeze_last_n_layers:]:
42
- for param in layer.parameters():
43
- param.requires_grad = True
44
-
45
- # Single linear layer to match the trained checkpoint
46
  self.classifier = nn.Linear(self.hidden_size, num_labels)
47
  self.num_labels = num_labels
48
 
@@ -53,310 +26,173 @@ class WaveLmStutterClassification(nn.Module):
53
  logits = self.classifier(pooled)
54
  return logits
55
 
56
- # ============================================================================
57
- # STUTTER LABELS & DEFINITIONS
58
- # ============================================================================
59
-
60
  STUTTER_LABELS = ['Prolongation', 'Block', 'SoundRep', 'WordRep', 'Interjection']
61
 
62
  STUTTER_DEFINITIONS = {
63
- 'Prolongation': 'Sound stretched longer than normal (e.g., "Ssssssnake")',
64
- 'Block': 'Complete stoppage of airflow/sound with tension',
65
- 'SoundRep': 'Sound/syllable repetition (e.g., "B-b-b-ball")',
66
- 'WordRep': 'Whole word repetition (e.g., "I-I-I want")',
67
- 'Interjection': 'Filler words like "um", "uh", "like"'
68
  }
69
 
70
- SEVERITY_THRESHOLDS = {'very_mild': 5, 'mild': 10, 'moderate': 20, 'severe': 30}
71
-
72
- # ============================================================================
73
- # GLOBAL MODEL LOADING
74
- # ============================================================================
75
-
76
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
77
  wavlm_model = None
78
  whisper_model = None
 
79
 
80
  def load_models():
81
- global wavlm_model, whisper_model
82
-
83
- # Load WavLM
84
- print("Loading WavLM model...")
85
- wavlm_model = WaveLmStutterClassification(num_labels=5)
86
-
87
- # Try to load checkpoint
88
- checkpoint_path = "wavlm_stutter_classification_best.pth"
89
- if os.path.exists(checkpoint_path):
90
- checkpoint = torch.load(checkpoint_path, map_location=device)
91
- # Handle both formats: direct state_dict OR wrapped in 'model_state_dict'
92
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
93
- wavlm_model.load_state_dict(checkpoint['model_state_dict'])
94
- print(f"Loaded checkpoint with {checkpoint.get('val_accuracy', 'N/A')} accuracy")
95
- else:
96
- # Direct state_dict (how train_waveLM.py saves it)
97
- wavlm_model.load_state_dict(checkpoint)
98
- print("Loaded checkpoint (direct state_dict format)")
99
- else:
100
- print("WARNING: No checkpoint found, using random weights")
101
-
102
- wavlm_model.to(device)
103
- wavlm_model.eval()
104
-
105
- # Load Whisper
106
- print("Loading Whisper model...")
107
- whisper_model = whisper.load_model("base", device=device)
108
-
109
- print("Models loaded!")
110
-
111
- # ============================================================================
112
- # ANALYSIS FUNCTIONS
113
- # ============================================================================
114
-
115
- def preprocess_audio(audio_path):
116
- """Convert audio to 16kHz mono using soundfile or librosa."""
117
  try:
118
- # Try loading with soundfile first (faster)
119
- waveform_np, sr = sf.read(audio_path, dtype='float32')
120
-
121
- # Handle multi-channel (soundfile returns (samples, channels))
122
- if len(waveform_np.shape) > 1:
123
- waveform_np = waveform_np.mean(axis=1)
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
- print(f"Soundfile load failed, trying librosa: {e}")
127
- # Fallback to librosa (handles mp3/m4a better via ffmpeg)
128
- # librosa loads as mono by default, and we can force sr=16000 here
129
- waveform_np, sr = librosa.load(audio_path, sr=16000, mono=True)
130
-
131
- # Convert to tensor
132
- waveform = torch.from_numpy(waveform_np).float()
133
-
134
- # Resample if needed (only if soundfile was used and sr != 16000)
135
- # If librosa was used, it's already 16000
136
- if sr != 16000:
137
- resampler = torchaudio.transforms.Resample(sr, 16000)
138
- waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
139
-
140
- return waveform, 16000
141
 
142
- def chunk_audio(waveform, sr, chunk_sec=3.0):
143
- """Split audio into chunks"""
144
- chunk_samples = int(chunk_sec * sr)
145
- chunks = []
146
-
147
- for start in range(0, len(waveform), chunk_samples):
148
- end = min(start + chunk_samples, len(waveform))
149
- chunk = waveform[start:end]
150
-
151
- # Pad if needed
152
- if len(chunk) < chunk_samples:
153
- chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
154
-
155
- chunks.append({
156
- 'chunk': chunk,
157
- 'start': start / sr,
158
- 'end': end / sr
159
- })
160
-
161
- return chunks
 
162
 
163
- def analyze_chunk(chunk_waveform, threshold=0.5):
164
- """Run WavLM on a single chunk"""
165
  with torch.no_grad():
166
- input_tensor = chunk_waveform.unsqueeze(0).to(device)
167
- logits = wavlm_model(input_tensor)
168
  probs = torch.sigmoid(logits).cpu().numpy()[0]
169
-
170
  detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold]
171
- probabilities = {STUTTER_LABELS[i]: float(probs[i]) for i in range(len(STUTTER_LABELS))}
172
-
173
- return {'detected': detected, 'probabilities': probabilities}
174
 
175
- def get_severity(word_stutter_rate):
176
- """Calculate severity from word stutter rate"""
177
- if word_stutter_rate < SEVERITY_THRESHOLDS['very_mild']:
178
- return 'Very Mild', 1
179
- elif word_stutter_rate < SEVERITY_THRESHOLDS['mild']:
180
- return 'Mild', 2
181
- elif word_stutter_rate < SEVERITY_THRESHOLDS['moderate']:
182
- return 'Moderate', 3
183
- elif word_stutter_rate < SEVERITY_THRESHOLDS['severe']:
184
- return 'Severe', 4
185
- else:
186
- return 'Very Severe', 5
187
-
188
- # ============================================================================
189
- # MAIN ANALYSIS FUNCTION
190
- # ============================================================================
191
-
192
- def analyze_audio(audio_file, threshold=0.5):
193
- """Main analysis function for Gradio"""
194
 
195
- if wavlm_model is None:
196
- load_models()
 
 
 
 
 
197
 
198
- if audio_file is None:
199
- return "⚠️ Please upload an audio file", "", "", ""
 
 
200
 
201
  try:
202
- print(f"Starting analysis of: {audio_file}")
 
203
 
204
- # Preprocess
205
- waveform, sr = preprocess_audio(audio_file)
206
  duration = len(waveform) / sr
207
- print(f"Audio preprocessed: {duration:.1f}s, {sr}Hz")
208
-
209
- # Chunk and analyze with WavLM
210
- chunks = chunk_audio(waveform, sr)
211
 
212
- stutter_counts = {label: 0 for label in STUTTER_LABELS}
 
213
  timeline = []
214
 
215
- for chunk_info in chunks:
216
- result = analyze_chunk(chunk_info['chunk'], threshold)
217
- for label in result['detected']:
218
- stutter_counts[label] += 1
 
219
 
220
- timeline.append({
221
- 'time': f"{chunk_info['start']:.1f}s - {chunk_info['end']:.1f}s",
222
- 'detected': ', '.join(result['detected']) if result['detected'] else 'Clear',
223
- 'probs': result['probabilities']
224
- })
225
 
226
- # Transcribe with Whisper
227
- whisper_result = whisper_model.transcribe(audio_file, word_timestamps=True)
228
- transcription = whisper_result['text']
229
 
230
- # Get word-level info
231
- words = []
232
- if 'segments' in whisper_result:
233
- for seg in whisper_result['segments']:
234
- if 'words' in seg:
235
- words.extend(seg['words'])
236
 
237
- # Map stutters to words
238
- words_with_stutter = 0
239
- annotated_words = []
240
 
241
- for word_info in words:
242
- word_start = word_info.get('start', 0)
243
- word_end = word_info.get('end', 0)
244
- word_text = word_info.get('word', '')
245
-
246
- word_stutters = []
247
- for chunk_info in chunks:
248
- if word_start < chunk_info['end'] and word_end > chunk_info['start']:
249
- result = analyze_chunk(chunk_info['chunk'], threshold)
250
- word_stutters.extend(result['detected'])
251
-
252
- word_stutters = list(set(word_stutters))
253
- if word_stutters:
254
- words_with_stutter += 1
255
- annotated_words.append(f"**[{word_text}]**({', '.join(word_stutters)})")
256
- else:
257
- annotated_words.append(word_text)
258
 
259
- # Calculate metrics
260
- total_words = len(words) if words else 1
261
- word_stutter_rate = (words_with_stutter / total_words) * 100
262
- severity_label, severity_score = get_severity(word_stutter_rate)
263
-
264
- # Format outputs
265
- summary = f"""
266
- ## 📊 Analysis Summary
267
-
268
- **Duration:** {duration:.1f} seconds
269
- **Total Words:** {total_words}
270
- **Words with Stutters:** {words_with_stutter} ({word_stutter_rate:.1f}%)
271
-
272
- ### Severity: {severity_label} ({severity_score}/5)
273
-
274
- ### Stutter Type Counts:
275
- """
276
- for label, count in stutter_counts.items():
277
- if count > 0:
278
- summary += f"- **{label}**: {count} occurrences\n"
279
-
280
- # Annotated transcription
281
- annotated_text = " ".join(annotated_words) if annotated_words else transcription
282
-
283
- # Timeline
284
- timeline_text = "| Time | Detected Stutters |\n|------|-------------------|\n"
285
- for t in timeline[:15]: # Limit to 15 rows
286
- timeline_text += f"| {t['time']} | {t['detected']} |\n"
287
-
288
- # Definitions
289
- definitions = "## 📖 Stutter Type Definitions\n\n"
290
- for label, desc in STUTTER_DEFINITIONS.items():
291
- definitions += f"**{label}:** {desc}\n\n"
292
-
293
- return summary, annotated_text, timeline_text, definitions
294
 
295
  except Exception as e:
296
- import traceback
297
- error_trace = traceback.format_exc()
298
- print(f"Error in analyze_audio: {error_trace}")
299
- return f"❌ Error: {str(e)}\n\n```\n{error_trace}\n```", "", "", ""
300
 
301
- # ============================================================================
302
- # GRADIO INTERFACE
303
- # ============================================================================
304
 
305
- with gr.Blocks(title="🎙️ Stutter Analysis") as demo:
306
- gr.Markdown("""
307
- # 🎙️ Speech Fluency Analysis System
308
-
309
- Upload an audio file to analyze stuttering patterns using AI.
310
-
311
- **Supported formats:** WAV, MP3, M4A, FLAC
312
- """)
313
 
314
  with gr.Row():
315
- with gr.Column(scale=1):
316
- audio_input = gr.Audio(
317
- label="Upload Audio",
318
- type="filepath",
319
- sources=["upload", "microphone"]
320
- )
321
- threshold_slider = gr.Slider(
322
- minimum=0.3,
323
- maximum=0.7,
324
- value=0.5,
325
- step=0.05,
326
- label="Detection Threshold",
327
- info="Lower = more sensitive, Higher = more conservative"
328
- )
329
- analyze_btn = gr.Button("🔍 Analyze Speech", variant="primary")
330
-
331
- with gr.Column(scale=2):
332
- summary_output = gr.Markdown(label="Summary")
333
 
334
  with gr.Tabs():
335
- with gr.Tab("📝 Transcription"):
336
- transcription_output = gr.Markdown(label="Annotated Transcription")
337
-
338
- with gr.Tab("📈 Timeline"):
339
- timeline_output = gr.Markdown(label="Timeline Analysis")
340
-
341
- with gr.Tab("📖 Definitions"):
342
- definitions_output = gr.Markdown(label="Stutter Definitions")
343
-
344
- analyze_btn.click(
345
- fn=analyze_audio,
346
- inputs=[audio_input, threshold_slider],
347
- outputs=[summary_output, transcription_output, timeline_output, definitions_output]
348
- )
349
-
350
- gr.Markdown("""
351
- ---
352
- **Disclaimer:** This tool is for educational/research purposes.
353
- Consult a qualified speech-language pathologist for clinical diagnosis.
354
 
355
- Built with WavLM + Whisper | [GitHub](https://github.com/abhicodes-here2001/Multimodal-stuttering-analysis)
356
- """)
357
 
358
- # Load models on startup
359
  load_models()
360
 
361
- if __name__ == "__main__":
362
- demo.launch(theme=gr.themes.Soft())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
 
4
  import os
5
+ import traceback
 
 
6
  from datetime import datetime
7
  from transformers import WavLMModel
8
  import torch.nn as nn
 
9
 
10
+ print(f"APP STARTUP: {datetime.now()}")
 
 
11
 
12
  class WaveLmStutterClassification(nn.Module):
13
+ def __init__(self, num_labels=5):
14
  super().__init__()
15
  self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
16
  self.hidden_size = self.wavlm.config.hidden_size
17
+ for param in self.wavlm.parameters():
18
+ param.requires_grad = False
 
 
 
 
 
 
 
 
 
19
  self.classifier = nn.Linear(self.hidden_size, num_labels)
20
  self.num_labels = num_labels
21
 
 
26
  logits = self.classifier(pooled)
27
  return logits
28
 
 
 
 
 
# Disfluency classes reported by the classifier. analyze_chunk indexes this
# list by logit position, so the order here must match the model's outputs.
STUTTER_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
]

# One-line, plain-language description for each entry of STUTTER_LABELS.
STUTTER_DEFINITIONS = {
    "Prolongation": "Sound stretched longer than normal",
    "Block": "Complete stoppage of airflow/sound",
    "SoundRep": "Sound/syllable repetition",
    "WordRep": "Whole word repetition",
    "Interjection": "Filler words like um, uh",
}
38
 
 
 
 
 
 
 
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Model handles are filled in by load_models(); models_loaded records whether
# that has already succeeded so the work is not repeated.
wavlm_model = None
whisper_model = None
models_loaded = False
45
 
def load_models():
    """Load the WavLM stutter classifier and the Whisper transcriber once.

    The module-level ``wavlm_model`` / ``whisper_model`` handles are only
    assigned after *both* models have loaded, so a failure part-way through
    never leaves a half-initialised pair behind.

    Returns:
        True when both models are available (already loaded, or loaded by
        this call); False when loading raised. In the failure case the error
        and its traceback are printed and the globals are left untouched.
    """
    global wavlm_model, whisper_model, models_loaded
    if models_loaded:
        return True

    checkpoint_path = "wavlm_stutter_classification_best.pth"
    try:
        print("Loading WavLM...")
        # One output per label: analyze_chunk maps logit i to STUTTER_LABELS[i].
        classifier = WaveLmStutterClassification(num_labels=len(STUTTER_LABELS))
        if os.path.exists(checkpoint_path):
            # NOTE(security): weights_only=False lets torch.load unpickle
            # arbitrary objects. Acceptable only because the checkpoint ships
            # with the app; never point this at an untrusted file.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            # Accept either a bare state_dict or one wrapped under
            # the 'model_state_dict' key.
            is_wrapped = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
            state_dict = checkpoint['model_state_dict'] if is_wrapped else checkpoint
            classifier.load_state_dict(state_dict)
            print("Checkpoint loaded!")
        else:
            # Without the checkpoint the linear classifier head is untrained,
            # so every prediction is meaningless; make that visible in the logs.
            print(f"WARNING: No checkpoint found at {checkpoint_path}, using random weights")
        classifier.to(device)
        classifier.eval()

        print("Loading Whisper...")
        # Imported lazily so a missing/broken whisper install is reported via
        # the return value instead of crashing the module import.
        import whisper
        transcriber = whisper.load_model("base", device=device)
    except Exception as e:
        print(f"Model loading error: {e}")
        traceback.print_exc()
        return False

    wavlm_model = classifier
    whisper_model = transcriber
    models_loaded = True
    print("Models loaded!")
    return True
 
 
 
 
 
 
 
 
 
 
 
 
75
 
def load_audio(audio_path):
    """Decode an audio file into a mono 16 kHz float tensor.

    librosa is tried first (it is asked to resample to 16 kHz and downmix to
    mono). If it raises for any reason, soundfile is tried instead: channels
    are averaged and the signal is resampled with torchaudio when the file's
    native rate is not 16 kHz.

    Args:
        audio_path: Path of the audio file to read.

    Returns:
        ``(waveform, 16000)`` where ``waveform`` is a float torch tensor.

    Raises:
        RuntimeError: If both decoders fail. The message includes both
            underlying errors and the exception is chained to the soundfile
            failure, so callers that display ``str(e)`` show the real cause.
    """
    target_sr = 16000
    print(f"Loading: {audio_path}")

    try:
        import librosa

        samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)
        return torch.from_numpy(samples).float(), target_sr
    except Exception as e:
        print(f"librosa error: {e}")
        # Keep a reference: the name bound by "except ... as" is cleared when
        # the handler exits, and the error is needed for the final message.
        librosa_error = e

    try:
        import soundfile as sf

        samples, native_sr = sf.read(audio_path, dtype='float32')
        if samples.ndim > 1:
            # Average the channel axis to get a mono signal.
            samples = samples.mean(axis=1)
        waveform = torch.from_numpy(samples).float()
        if native_sr != target_sr:
            import torchaudio

            resample = torchaudio.transforms.Resample(native_sr, target_sr)
            # Add a leading dimension for the transform, then drop it again.
            waveform = resample(waveform.unsqueeze(0)).squeeze(0)
        return waveform, target_sr
    except Exception as e:
        print(f"soundfile error: {e}")
        raise RuntimeError(
            f"Could not load audio (librosa: {librosa_error}; soundfile: {e})"
        ) from e
97
 
def analyze_chunk(chunk_tensor, threshold=0.5):
    """Classify one audio window with the loaded WavLM model.

    Args:
        chunk_tensor: Waveform tensor for a single window (no batch
            dimension; one is added here before the forward pass).
        threshold: A label counts as detected when its sigmoid probability
            is strictly greater than this value.

    Returns:
        ``(detected, probabilities)``: the list of label names above the
        threshold, and a dict mapping every entry of STUTTER_LABELS to its
        probability as a Python float.
    """
    with torch.no_grad():
        batch = chunk_tensor.unsqueeze(0).to(device)
        logits = wavlm_model(batch)
    # Independent sigmoid per label (multi-label), first and only batch row.
    scores = torch.sigmoid(logits).cpu().numpy()[0]

    flagged = []
    for index, score in enumerate(scores):
        if score > threshold:
            flagged.append(STUTTER_LABELS[index])

    per_label = dict(zip(STUTTER_LABELS, scores.tolist()))
    return flagged, per_label
 
 
104
 
def analyze_audio(audio_input, threshold):
    """Run stutter detection and transcription for one analysis request.

    The audio is decoded with ``load_audio``, cut into consecutive 3-second
    windows (the last one zero-padded to full length) and each window is
    classified with ``analyze_chunk``. The recording is then transcribed
    with the Whisper model.

    Args:
        audio_input: A path to an audio file, a ``(sample_rate, samples)``
            tuple (written to a temporary WAV file first), or ``None`` when
            nothing was provided.
        threshold: Per-label probability cut-off passed to ``analyze_chunk``.

    Returns:
        A 4-tuple of strings ``(summary, transcription, timeline, definitions)``.
        When the input is missing or anything fails, the first element
        carries the message and the remaining three are empty strings.
    """
    chunk_seconds = 3.0
    # The timeline table is capped so long recordings stay readable.
    max_timeline_rows = 15

    print("\n=== ANALYZE CLICKED ===")
    print(f"Input: {audio_input}, Type: {type(audio_input)}, Threshold: {threshold}")

    if audio_input is None:
        return "Please upload an audio file first!", "", "", ""

    audio_path = audio_input
    temp_path = None  # set only when this call created a temporary file
    try:
        if isinstance(audio_input, tuple):
            import tempfile

            import soundfile as sf

            sample_rate, samples = audio_input
            # Reserve a unique name, then close the handle straight away so
            # soundfile can reopen the path (an open handle blocks that on
            # some platforms) and no descriptor is leaked.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
                temp_path = handle.name
            sf.write(temp_path, samples, sample_rate)
            audio_path = temp_path

        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}", "", "", ""

        print(f"File: {audio_path}, Size: {os.path.getsize(audio_path)}")

        if not models_loaded and not load_models():
            return "Failed to load models", "", "", ""

        waveform, sr = load_audio(audio_path)
        duration = len(waveform) / sr
        print(f"Duration: {duration:.1f}s")

        chunk_samples = int(chunk_seconds * sr)
        stutter_counts = {label: 0 for label in STUTTER_LABELS}
        timeline = []

        for start in range(0, len(waveform), chunk_samples):
            end = min(start + chunk_samples, len(waveform))
            chunk = waveform[start:end]
            if len(chunk) < chunk_samples:
                # Zero-pad the final window so every input has the same length.
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            detected, _ = analyze_chunk(chunk, threshold)
            for label in detected:
                stutter_counts[label] += 1
            timeline.append({
                "time": f"{start/sr:.1f}-{end/sr:.1f}s",
                "detected": detected or ["Clear"],
            })

        print("Running Whisper...")
        transcription = whisper_model.transcribe(audio_path).get('text', '')

        total = sum(stutter_counts.values())
        summary = (
            f"## Analysis Complete\n\n**Duration:** {duration:.1f}s\n**Stutters:** {total}\n\n"
            + "".join(f"- {label}: {count}\n" for label, count in stutter_counts.items())
        )

        timeline_md = "| Time | Detected |\n|---|---|\n" + "".join(
            f"| {row['time']} | {', '.join(row['detected'])} |\n"
            for row in timeline[:max_timeline_rows]
        )

        defs = "\n".join(f"**{name}:** {text}" for name, text in STUTTER_DEFINITIONS.items())

        print("Done!")
        return summary, transcription, timeline_md, defs

    except Exception as e:
        # Top-level boundary for the request: report the failure in the
        # summary slot instead of letting it escape to the UI framework.
        print(f"Error: {e}")
        traceback.print_exc()
        return f"Error: {e}\n\n{traceback.format_exc()}", "", "", ""

    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                # Best effort: a leftover temp file must not fail the request.
                print(f"Could not remove temp file {temp_path}: {cleanup_error}")
 
169
 
print("Building UI...")

# Layout: inputs in the left column, results (summary + tabbed details) in
# the right column. The button wires both inputs into analyze_audio and
# spreads its four returned strings over the four Markdown panes.
with gr.Blocks(title="Stutter Analysis") as demo:
    gr.Markdown(value="# Speech Fluency Analysis\nUpload audio to analyze stuttering.")

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Upload Audio")
            threshold = gr.Slider(minimum=0.3, maximum=0.7, value=0.5, label="Threshold")
            btn = gr.Button(value="Analyze", variant="primary")
        with gr.Column():
            summary = gr.Markdown(value="Upload audio and click Analyze")

            with gr.Tabs():
                with gr.TabItem(label="Transcription"):
                    trans = gr.Markdown()
                with gr.TabItem(label="Timeline"):
                    timeline = gr.Markdown()
                with gr.TabItem(label="Definitions"):
                    defs = gr.Markdown()

    btn.click(
        fn=analyze_audio,
        inputs=[audio, threshold],
        outputs=[summary, trans, timeline, defs],
    )

# Load the models before serving so the first request does not pay for it.
print("Loading models...")
load_models()

print("Launching...")
demo.queue()
demo.launch()
requirements.txt CHANGED
@@ -1,6 +1,3 @@
1
- # Hugging Face Spaces Requirements
2
- # For Gradio deployment
3
-
4
  torch>=2.0.0
5
  torchaudio>=2.0.0
6
  transformers>=4.30.0
 
 
 
 
1
  torch>=2.0.0
2
  torchaudio>=2.0.0
3
  transformers>=4.30.0