throgletworld committed on
Commit
fb9af37
Β·
verified Β·
1 Parent(s): 8dd20ec

Upload 3 files

Browse files
app.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Spaces - Gradio App for Stutter Analysis
3
+ =====================================================
4
+ This is a standalone Gradio app for deployment on Hugging Face Spaces.
5
+
6
+ To deploy:
7
+ 1. Create a new Space on huggingface.co/spaces
8
+ 2. Choose "Gradio" as SDK
9
+ 3. Upload this folder's contents
10
+ 4. Add your model checkpoint to the Space
11
+ """
12
+
13
+ import gradio as gr
14
+ import torch
15
+ import torchaudio
16
+ import tempfile
17
+ import os
18
+ import json
19
+ from datetime import datetime
20
+ from transformers import WavLMModel
21
+ import torch.nn as nn
22
+ import whisper
23
+
24
+ # ============================================================================
25
+ # MODEL DEFINITION (same as models/WaveLm_model.py)
26
+ # ============================================================================
27
+
28
class WaveLmStutterClassification(nn.Module):
    """WavLM encoder with a small MLP head for multi-label stutter classification.

    Produces one logit per stutter type; callers apply a sigmoid, so the five
    outputs are independent probabilities rather than a softmax distribution.
    """

    def __init__(self, num_labels=5, freeze_encoder=True, unfreeze_last_n_layers=1):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size

        if freeze_encoder:
            # Freeze the whole encoder, then selectively re-enable the top
            # transformer layers for fine-tuning.
            self.wavlm.requires_grad_(False)
            if unfreeze_last_n_layers > 0:
                for block in self.wavlm.encoder.layers[-unfreeze_last_n_layers:]:
                    block.requires_grad_(True)

        # Two-layer MLP classification head on top of the pooled embedding.
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_labels),
        )
        self.num_labels = num_labels

    def forward(self, input_values, attention_mask=None):
        """Return raw (pre-sigmoid) logits of shape (batch, num_labels)."""
        encoded = self.wavlm(input_values, attention_mask=attention_mask)
        # Mean-pool the hidden states over the time axis before classifying.
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(pooled)
57
+
58
+ # ============================================================================
59
+ # STUTTER LABELS & DEFINITIONS
60
+ # ============================================================================
61
+
62
# The five dysfluency classes the classifier predicts; index order maps
# directly to the model's output logits (see analyze_chunk).
STUTTER_LABELS = ['Prolongation', 'Block', 'SoundRep', 'WordRep', 'Interjection']

# Human-readable explanations rendered in the UI's "Definitions" tab.
STUTTER_DEFINITIONS = {
    'Prolongation': 'Sound stretched longer than normal (e.g., "Ssssssnake")',
    'Block': 'Complete stoppage of airflow/sound with tension',
    'SoundRep': 'Sound/syllable repetition (e.g., "B-b-b-ball")',
    'WordRep': 'Whole word repetition (e.g., "I-I-I want")',
    'Interjection': 'Filler words like "um", "uh", "like"'
}

# Upper bounds (percent of words containing a stutter) for each severity
# bucket; anything at or above 'severe' is classified "Very Severe".
SEVERITY_THRESHOLDS = {'very_mild': 5, 'mild': 10, 'moderate': 20, 'severe': 30}
73
+
74
+ # ============================================================================
75
+ # GLOBAL MODEL LOADING
76
+ # ============================================================================
77
+
78
# Prefer GPU when one is available; all models and tensors go to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-wide model handles; populated by load_models() (lazily re-checked in
# analyze_audio in case startup loading did not run).
wavlm_model = None
whisper_model = None
81
+
82
def load_models():
    """Load the WavLM stutter classifier and the Whisper transcriber.

    Populates the module-level globals ``wavlm_model`` and ``whisper_model``.
    Falls back to randomly initialised classifier weights when the checkpoint
    file is not present alongside the app.
    """
    global wavlm_model, whisper_model

    print("Loading WavLM model...")
    wavlm_model = WaveLmStutterClassification(num_labels=5)

    # Restore fine-tuned weights if the checkpoint was uploaded to the Space.
    ckpt_path = "wavlm_stutter_classification_best.pth"
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device)
        wavlm_model.load_state_dict(ckpt['model_state_dict'])
        print(f"Loaded checkpoint with {ckpt.get('val_accuracy', 'N/A')} accuracy")
    else:
        print("WARNING: No checkpoint found, using random weights")

    # Inference-only: move to the target device and disable dropout.
    wavlm_model.to(device).eval()

    print("Loading Whisper model...")
    whisper_model = whisper.load_model("base", device=device)

    print("Models loaded!")
106
+
107
+ # ============================================================================
108
+ # ANALYSIS FUNCTIONS
109
+ # ============================================================================
110
+
111
def preprocess_audio(audio_path):
    """Load *audio_path* and normalise it for the models.

    Returns a ``(waveform, sample_rate)`` pair where the waveform is a 1-D
    mono tensor and the sample rate is always 16000 (both WavLM and Whisper
    expect 16 kHz input).
    """
    signal, rate = torchaudio.load(audio_path)

    # Downmix multi-channel audio by averaging the channels.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Resample anything that is not already 16 kHz.
    if rate != 16000:
        signal = torchaudio.transforms.Resample(rate, 16000)(signal)

    return signal.squeeze(0), 16000
125
+
126
def chunk_audio(waveform, sr, chunk_sec=3.0):
    """Slice *waveform* into fixed-length windows of *chunk_sec* seconds.

    Each returned dict carries the sample window under ``'chunk'`` plus its
    ``'start'``/``'end'`` times in seconds. The final window is right-padded
    with zeros so every chunk has identical length (required for batching
    through the classifier). An empty waveform yields an empty list.
    """
    window = int(chunk_sec * sr)
    total = len(waveform)
    pieces = []

    for begin in range(0, total, window):
        stop = min(begin + window, total)
        segment = waveform[begin:stop]

        # Zero-pad a short trailing segment up to the full window size.
        deficit = window - len(segment)
        if deficit > 0:
            segment = torch.nn.functional.pad(segment, (0, deficit))

        pieces.append({
            'chunk': segment,
            'start': begin / sr,
            'end': stop / sr,
        })

    return pieces
146
+
147
def analyze_chunk(chunk_waveform, threshold=0.5):
    """Classify one audio chunk with the global WavLM model.

    Returns ``{'detected': [...], 'probabilities': {...}}`` where ``detected``
    lists labels whose sigmoid score exceeds *threshold* (in label order) and
    ``probabilities`` maps every label to its score.
    """
    with torch.no_grad():
        batch = chunk_waveform.unsqueeze(0).to(device)  # add batch dimension
        scores = torch.sigmoid(wavlm_model(batch)).cpu().numpy()[0]

    probabilities = {}
    detected = []
    for label, score in zip(STUTTER_LABELS, scores):
        probabilities[label] = float(score)
        if score > threshold:
            detected.append(label)

    return {'detected': detected, 'probabilities': probabilities}
158
+
159
def get_severity(word_stutter_rate):
    """Map a percent-of-words-stuttered rate to a (label, 1-5 score) pair.

    Buckets are defined by SEVERITY_THRESHOLDS (exclusive upper bounds);
    rates at or above the 'severe' threshold fall through to 'Very Severe'.
    """
    bands = (
        (SEVERITY_THRESHOLDS['very_mild'], 'Very Mild', 1),
        (SEVERITY_THRESHOLDS['mild'], 'Mild', 2),
        (SEVERITY_THRESHOLDS['moderate'], 'Moderate', 3),
        (SEVERITY_THRESHOLDS['severe'], 'Severe', 4),
    )
    for upper, label, score in bands:
        if word_stutter_rate < upper:
            return label, score
    return 'Very Severe', 5
171
+
172
+ # ============================================================================
173
+ # MAIN ANALYSIS FUNCTION
174
+ # ============================================================================
175
+
176
def analyze_audio(audio_file, threshold=0.5):
    """Main analysis function for Gradio.

    Args:
        audio_file: Path to the uploaded audio file (Gradio ``filepath``
            input), or None when nothing was uploaded.
        threshold: Sigmoid-probability cutoff above which a stutter type
            counts as detected in a chunk.

    Returns:
        A 4-tuple of markdown strings: (summary, annotated transcription,
        timeline table, stutter-type definitions). On failure the first
        element carries the error message and the others are empty.
    """

    # Lazy-load in case module-level startup loading did not run.
    if wavlm_model is None:
        load_models()

    if audio_file is None:
        return "Please upload an audio file", "", "", ""

    try:
        # Preprocess to 16 kHz mono.
        waveform, sr = preprocess_audio(audio_file)
        duration = len(waveform) / sr

        # Chunk and analyze with WavLM.
        chunks = chunk_audio(waveform, sr)

        stutter_counts = {label: 0 for label in STUTTER_LABELS}
        timeline = []
        chunk_results = []  # classifier output per chunk, computed exactly once

        for chunk_info in chunks:
            # BUGFIX: the classifier used to be re-run on every chunk for
            # every transcribed word below (O(words x chunks) forward
            # passes). Run each chunk once here and reuse the cached result.
            result = analyze_chunk(chunk_info['chunk'], threshold)
            chunk_results.append(result)

            for label in result['detected']:
                stutter_counts[label] += 1

            timeline.append({
                'time': f"{chunk_info['start']:.1f}s - {chunk_info['end']:.1f}s",
                'detected': ', '.join(result['detected']) if result['detected'] else 'Clear',
                'probs': result['probabilities']
            })

        # Transcribe with Whisper; word timestamps let us map stutters to words.
        whisper_result = whisper_model.transcribe(audio_file, word_timestamps=True)
        transcription = whisper_result['text']

        # Flatten the per-segment word lists into one word sequence.
        words = []
        if 'segments' in whisper_result:
            for seg in whisper_result['segments']:
                if 'words' in seg:
                    words.extend(seg['words'])

        # Map stutters to words via time overlap with the analysis chunks.
        words_with_stutter = 0
        annotated_words = []

        for word_info in words:
            word_start = word_info.get('start', 0)
            word_end = word_info.get('end', 0)
            word_text = word_info.get('word', '')

            word_stutters = []
            for chunk_info, result in zip(chunks, chunk_results):
                # Interval-overlap test between the word span and the chunk.
                if word_start < chunk_info['end'] and word_end > chunk_info['start']:
                    word_stutters.extend(result['detected'])

            # Dedupe while keeping label order deterministic (the previous
            # set()-based dedupe produced arbitrary ordering in the text).
            word_stutters = list(dict.fromkeys(word_stutters))
            if word_stutters:
                words_with_stutter += 1
                annotated_words.append(f"**[{word_text}]**({', '.join(word_stutters)})")
            else:
                annotated_words.append(word_text)

        # Calculate metrics; guard against division by zero when Whisper
        # returned no word timestamps.
        total_words = len(words) if words else 1
        word_stutter_rate = (words_with_stutter / total_words) * 100
        severity_label, severity_score = get_severity(word_stutter_rate)

        # Format outputs.
        summary = f"""
## 📊 Analysis Summary

**Duration:** {duration:.1f} seconds
**Total Words:** {total_words}
**Words with Stutters:** {words_with_stutter} ({word_stutter_rate:.1f}%)

### Severity: {severity_label} ({severity_score}/5)

### Stutter Type Counts:
"""
        for label, count in stutter_counts.items():
            if count > 0:
                summary += f"- **{label}**: {count} occurrences\n"

        # Annotated transcription; fall back to raw text when no words parsed.
        annotated_text = " ".join(annotated_words) if annotated_words else transcription

        # Timeline as a markdown table.
        timeline_text = "| Time | Detected Stutters |\n|------|-------------------|\n"
        for t in timeline[:15]:  # Limit to 15 rows
            timeline_text += f"| {t['time']} | {t['detected']} |\n"

        # Definitions tab content.
        definitions = "## 📖 Stutter Type Definitions\n\n"
        for label, desc in STUTTER_DEFINITIONS.items():
            definitions += f"**{label}:** {desc}\n\n"

        return summary, annotated_text, timeline_text, definitions

    except Exception as e:
        # Surface the error in the UI rather than crashing the app.
        return f"Error: {str(e)}", "", "", ""
278
+
279
+ # ============================================================================
280
+ # GRADIO INTERFACE
281
+ # ============================================================================
282
+
283
# Build the UI: inputs (audio + sensitivity slider) on the left, the summary
# on the right, and detailed results in tabs below.
with gr.Blocks(title="🎙️ Stutter Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ Speech Fluency Analysis System

    Upload an audio file to analyze stuttering patterns using AI.

    **Supported formats:** WAV, MP3, M4A, FLAC
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # File upload or live microphone recording; handed to
            # analyze_audio as a filesystem path.
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            # Per-chunk sigmoid cutoff forwarded to analyze_chunk.
            threshold_slider = gr.Slider(
                minimum=0.3,
                maximum=0.7,
                value=0.5,
                step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more conservative"
            )
            analyze_btn = gr.Button("🔍 Analyze Speech", variant="primary")

        with gr.Column(scale=2):
            summary_output = gr.Markdown(label="Summary")

    with gr.Tabs():
        with gr.Tab("📝 Transcription"):
            transcription_output = gr.Markdown(label="Annotated Transcription")

        with gr.Tab("📈 Timeline"):
            timeline_output = gr.Markdown(label="Timeline Analysis")

        with gr.Tab("📖 Definitions"):
            definitions_output = gr.Markdown(label="Stutter Definitions")

    # Wire the button to the analysis function; outputs map 1:1 onto the
    # 4-tuple returned by analyze_audio.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold_slider],
        outputs=[summary_output, transcription_output, timeline_output, definitions_output]
    )

    gr.Markdown("""
    ---
    **Disclaimer:** This tool is for educational/research purposes.
    Consult a qualified speech-language pathologist for clinical diagnosis.

    Built with WavLM + Whisper | [GitHub](https://github.com/abhicodes-here2001/Multimodal-stuttering-analysis)
    """)

# Load models on startup (at import time) so the first request is not slow.
load_models()

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces Requirements
2
+ # For Gradio deployment
3
+
4
+ torch>=2.0.0
5
+ torchaudio>=2.0.0
6
+ transformers>=4.30.0
7
+ gradio>=4.0.0
8
+ openai-whisper>=20231117
9
+ numpy>=1.24.0
wavlm_stutter_classification_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b98f4e50fa40a0cd43602858d77bff78692a68b14fb6bb7144b5d2a12155071b
3
+ size 377646731