throgletworld committed on
Commit e2ff604 · verified · 1 Parent(s): 745f57f

Upload 4 files

app.py ADDED
@@ -0,0 +1,462 @@
import gradio as gr
import torch
import numpy as np
import os
import traceback
from datetime import datetime
from transformers import WavLMModel
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f"APP STARTUP: {datetime.now()}")

# =============================================================================
# WHY SIGMOID INSTEAD OF SOFTMAX? - A DETAILED EXPLANATION
# =============================================================================
"""
MULTI-LABEL vs MULTI-CLASS CLASSIFICATION
==========================================

Our stutter detection is a MULTI-LABEL problem:
- A single 3-second audio chunk can have MULTIPLE stutters simultaneously
- Example: Someone might have a "Block" AND a "SoundRep" in the same chunk
- Each of the 5 stutter types is INDEPENDENT of the others

SOFTMAX (❌ NOT suitable for us):
---------------------------------
- Used for MULTI-CLASS problems where classes are MUTUALLY EXCLUSIVE
- Example: "Is this image a Cat OR a Dog?" (can't be both)
- Formula: softmax(x_i) = exp(x_i) / sum(exp(x_j)) for all j
- All probabilities MUST sum to 1.0
- Problem: If we used softmax and got [0.7, 0.1, 0.1, 0.05, 0.05]:
  - It would say "70% Prolongation" but FORCE other classes to be low
  - We couldn't detect multiple stutters in one chunk!

SIGMOID (✅ CORRECT for us):
----------------------------
- Used for MULTI-LABEL problems where classes are INDEPENDENT
- Each class gets its own independent probability (0 to 1)
- Formula: sigmoid(x) = 1 / (1 + exp(-x))
- Probabilities DON'T need to sum to 1
- Example output: [0.8, 0.7, 0.2, 0.1, 0.05]
  - 80% chance of Prolongation
  - 70% chance of Block
  - Both can be detected simultaneously!

THE TRAINING & INFERENCE FLOW:
==============================

TRAINING:
---------
1. Model outputs: LOGITS (raw scores from -inf to +inf)
   Example: [2.5, -3.0, 0.1, -1.5, -2.0]

2. Loss Function: BCEWithLogitsLoss
   - "WithLogits" means it applies Sigmoid INTERNALLY
   - More numerically stable than separate Sigmoid + BCELoss
   - Compares each prediction to each ground-truth label independently

INFERENCE (this file):
----------------------
1. Model outputs: LOGITS (same as training)
   Example: [2.5, -3.0, 0.1, -1.5, -2.0]

2. We manually apply Sigmoid to convert to probabilities:
   probs = torch.sigmoid(logits)
   Result: [0.92, 0.05, 0.52, 0.18, 0.12]

3. Apply threshold (e.g., 0.5) to each probability:
   - 0.92 > 0.5 → Prolongation DETECTED
   - 0.05 < 0.5 → Block NOT detected
   - 0.52 > 0.5 → SoundRep DETECTED
   - etc.

4. If NO stutters detected (all below threshold):
   → Label the chunk as "Fluent"

THRESHOLD EXPLAINED:
====================
- Default: 0.5 (theoretically neutral, since sigmoid(0) = 0.5)
- Lower threshold (0.3-0.4): More SENSITIVE, catches more stutters, but more false positives
- Higher threshold (0.6-0.7): More STRICT, fewer false positives, but might miss subtle stutters
- The slider in the UI lets users adjust this based on their needs
- The SAME threshold is applied to ALL 5 classes (simplest approach)
"""

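# ---------------------------------------------------------------------------
# Illustrative sanity check of the sigmoid-vs-softmax point above. Not called
# anywhere in the app; pure Python (math only), so it runs without torch.
# The logits passed in are made-up example values, not real model output.
# ---------------------------------------------------------------------------
def _demo_multilabel_decode(logits, labels, threshold=0.5):
    """Decode one chunk's logits the same way analyze_chunk() does below."""
    import math
    # Softmax: exponentiate and normalise, so probabilities sum to 1.0 and
    # the classes compete -- only suitable for mutually exclusive classes.
    exps = [math.exp(x) for x in logits]
    softmax_probs = [e / sum(exps) for e in exps]
    # Sigmoid: squash each logit independently, so several labels can be
    # above threshold at once -- what multi-label detection needs.
    sigmoid_probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    detected = [l for l, p in zip(labels, sigmoid_probs) if p > threshold]
    return softmax_probs, sigmoid_probs, detected or ["Fluent"]
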
class WaveLmStutterClassification(nn.Module):
    def __init__(self, num_labels=5):
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size
        # Freeze the WavLM backbone; only the linear head is trainable
        for param in self.wavlm.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(self.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, input_values, attention_mask=None):
        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Mean-pool over the time dimension, then project to one logit per label
        pooled = hidden_states.mean(dim=1)
        logits = self.classifier(pooled)
        return logits

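# ---------------------------------------------------------------------------
# Numerical aside, illustrative only: why the explanation above calls
# BCEWithLogitsLoss "more numerically stable" than Sigmoid + BCELoss. For a
# large-magnitude logit, sigmoid(x) rounds to exactly 0.0 or 1.0 in float and
# log(0) is -inf; the fused loss instead uses the algebraically equivalent
#     -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
#         = max(x, 0) - x*y + log(1 + exp(-|x|))
# which never evaluates log at 0. Pure-Python sketch of that stable form:
# ---------------------------------------------------------------------------
def _demo_stable_bce_with_logits(logit, target):
    import math
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
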
STUTTER_LABELS = ['Prolongation', 'Block', 'SoundRep', 'WordRep', 'Interjection']

STUTTER_DEFINITIONS = {
    'Prolongation': 'Sound stretched longer than normal',
    'Block': 'Complete stoppage of airflow/sound',
    'SoundRep': 'Sound/syllable repetition',
    'WordRep': 'Whole word repetition',
    'Interjection': 'Filler words like um, uh'
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

wavlm_model = None
whisper_model = None
medgemma_model = None
medgemma_tokenizer = None
models_loaded = False

def load_models():
    global wavlm_model, whisper_model, models_loaded, medgemma_model, medgemma_tokenizer
    if models_loaded:
        return True
    try:
        print("Loading WavLM...")
        wavlm_model = WaveLmStutterClassification(num_labels=5)
        checkpoint_path = "wavlm_stutter_classification_best.pth"
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                wavlm_model.load_state_dict(checkpoint['model_state_dict'])
            else:
                wavlm_model.load_state_dict(checkpoint)
            print("Checkpoint loaded!")
        wavlm_model.to(device)
        wavlm_model.eval()

        print("Loading Whisper...")
        import whisper
        whisper_model = whisper.load_model("base", device=device)

        # NOTE: MedGemma is lazy-loaded in the report-generation function,
        # to save startup time and VRAM.

        models_loaded = True
        print("Models loaded!")
        return True
    except Exception as e:
        print(f"Model loading error: {e}")
        traceback.print_exc()
        return False

def load_audio(audio_path):
    print(f"Loading: {audio_path}")
    try:
        import librosa
        waveform, sr = librosa.load(audio_path, sr=16000, mono=True)
        return torch.from_numpy(waveform).float(), 16000
    except Exception as e:
        print(f"librosa error: {e}")
        try:
            import soundfile as sf
            waveform, sr = sf.read(audio_path, dtype='float32')
            if len(waveform.shape) > 1:
                waveform = waveform.mean(axis=1)  # downmix to mono
            waveform = torch.from_numpy(waveform).float()
            if sr != 16000:
                import torchaudio
                waveform = torchaudio.transforms.Resample(sr, 16000)(waveform.unsqueeze(0)).squeeze(0)
            return waveform, 16000
        except Exception as e:
            print(f"soundfile error: {e}")
            raise Exception("Could not load audio")

# ============================================================================
# MEDGEMMA LOGIC
# ============================================================================

def load_medgemma_model():
    global medgemma_model, medgemma_tokenizer
    if medgemma_model is not None:
        return True

    print("Loading TxGemma 9B...")
    try:
        model_id = "google/txgemma-9b-predict"

        # Use 4-bit quantization if CUDA is available to save VRAM
        if device == "cuda":
            from transformers import BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            medgemma_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                device_map="auto"
            )
        else:
            # CPU or MPS: load without quantization
            medgemma_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                device_map="auto"
            )

        medgemma_tokenizer = AutoTokenizer.from_pretrained(model_id)
        print("MedGemma loaded!")
        return True
    except Exception as e:
        print(f"Error loading MedGemma: {e}")
        return False

def generate_medgemma_report(analysis_data, progress=gr.Progress()):
    if not analysis_data:
        return "⚠️ Please analyze an audio file first."

    progress(0.1, desc="🏥 Loading MedGemma...")
    success = load_medgemma_model()
    if not success:
        return "❌ Failed to load MedGemma model. Please check logs."

    progress(0.3, desc="📝 Preparing clinical data...")

    # Construct prompt
    prompt = f"""You are an expert Speech-Language Pathologist (SLP) assistant.
Based on the following automated stuttering analysis data, generate a professional clinical report.

## PATIENT INFORMATION
- Audio Duration: {analysis_data['duration']:.2f} seconds
- Total Words (Est): {analysis_data['word_count']}
- Speaking Rate: {analysis_data['speaking_rate']:.1f} words/min

## TRANSCRIPTION
"{analysis_data['transcription']}"

## STUTTERING ANALYSIS RESULTS
- Total Stutter Events: {analysis_data['total_stutters']}
- Stuttering Frequency: {analysis_data['frequency']:.1f}% of chunks affected

## STUTTER TYPE DISTRIBUTION
{analysis_data['distribution_str']}

---

Based on this data, please generate:
1. **CLINICAL SUMMARY** (2-3 sentences): Overview of fluency patterns.
2. **DETAILED FINDINGS**: Elaborate on types observed (Blocks, Prolongations, Repetitions).
3. **RECOMMENDATIONS** (3 bullets): Evidence-based therapy suggestions.

Write in a professional, empathetic clinical tone suitable for patient records."""

    messages = [
        {"role": "system", "content": "You are an expert SLP assistant."},
        {"role": "user", "content": prompt}
    ]

    progress(0.5, desc="🧠 Generating clinical narrative...")

    try:
        inputs = medgemma_tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(medgemma_model.device)

        with torch.no_grad():
            outputs = medgemma_model.generate(
                **inputs,
                max_new_tokens=800,
                do_sample=True,
                temperature=0.7
            )

        # Decode only the newly generated tokens, not the echoed prompt
        generated_text = medgemma_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return generated_text

    except Exception as e:
        return f"Error generating report: {str(e)}"

def analyze_chunk(chunk_tensor, threshold=0.5):
    with torch.no_grad():
        logits = wavlm_model(chunk_tensor.unsqueeze(0).to(device))
        probs = torch.sigmoid(logits).cpu().numpy()[0]
    detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold]
    return detected, dict(zip(STUTTER_LABELS, probs.tolist()))

def analyze_audio(audio_input, threshold, progress=gr.Progress()):
    print("\n=== ANALYZE CLICKED ===")
    print(f"Input: {audio_input}, Type: {type(audio_input)}, Threshold: {threshold}")

    progress(0, desc="🔄 Starting analysis...")

    if audio_input is None:
        # Five outputs: summary, transcription, timeline, definitions, analysis_state
        return "⚠️ Please upload an audio file first!", "", "", "", None

    audio_path = audio_input
    if isinstance(audio_input, tuple):
        # Numpy audio (e.g. from the microphone): write it to a temp WAV file
        import tempfile, soundfile as sf
        sr, data = audio_input
        f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        sf.write(f.name, data, sr)
        audio_path = f.name

    if not os.path.exists(audio_path):
        return f"File not found: {audio_path}", "", "", "", None

    print(f"File: {audio_path}, Size: {os.path.getsize(audio_path)}")

    try:
        progress(0.1, desc="🔄 Loading models...")
        if not models_loaded and not load_models():
            return "❌ Failed to load models", "", "", "", None

        progress(0.2, desc="🎵 Loading audio file...")
        waveform, sr = load_audio(audio_path)
        duration = len(waveform) / sr
        print(f"Duration: {duration:.1f}s")

        progress(0.3, desc="✂️ Splitting audio into chunks...")
        chunk_samples = int(3.0 * sr)
        stutter_counts = {l: 0 for l in STUTTER_LABELS}
        timeline = []

        total_chunks = (len(waveform) + chunk_samples - 1) // chunk_samples

        for i, start in enumerate(range(0, len(waveform), chunk_samples)):
            progress(0.3 + (0.4 * i / total_chunks), desc=f"🔍 Analyzing chunk {i+1}/{total_chunks}...")

            end = min(start + chunk_samples, len(waveform))
            chunk = waveform[start:end]
            if len(chunk) < chunk_samples:
                # Zero-pad the final chunk to a full 3 seconds
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            detected, _ = analyze_chunk(chunk, threshold)
            for l in detected:
                stutter_counts[l] += 1
            timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["Fluent"]})

        progress(0.75, desc="🗣️ Transcribing with Whisper...")
        print("Running Whisper...")
        transcription = whisper_model.transcribe(audio_path).get('text', '')

        progress(0.9, desc="📊 Generating report...")
        total = sum(stutter_counts.values())
        summary = f"## ✅ Analysis Complete!\n\n**Duration:** {duration:.1f}s\n**Total Stutters Detected:** {total}\n\n### Stutter Counts:\n"
        for l, c in stutter_counts.items():
            emoji = "🔴" if c > 0 else "⚪"
            summary += f"- {emoji} **{l}**: {c}\n"

        timeline_md = "| Time | Detected |\n|---|---|\n"
        for t in timeline[:15]:
            timeline_md += f"| {t['time']} | {', '.join(t['detected'])} |\n"
        if len(timeline) > 15:
            timeline_md += f"\n*...and {len(timeline) - 15} more chunks*"

        defs = "## 📖 Stutter Type Definitions\n\n"
        defs += "\n".join([f"**{k}:** {v}" for k, v in STUTTER_DEFINITIONS.items()])

        # Create analysis data for MedGemma
        analysis_data = {
            'duration': duration,
            'word_count': len(transcription.split()),
            'speaking_rate': (len(transcription.split()) / duration) * 60 if duration > 0 else 0,
            'transcription': transcription,
            'total_stutters': total,
            'frequency': (sum(1 for t in timeline if "Fluent" not in t['detected']) / total_chunks) * 100 if total_chunks > 0 else 0,
            'distribution_str': "\n".join([f"- {k}: {v} occurrences" for k, v in stutter_counts.items() if v > 0])
        }

        progress(1.0, desc="✅ Done!")
        print("Done!")
        return summary, transcription, timeline_md, defs, analysis_data

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return f"Error: {e}\n\n{traceback.format_exc()}", "", "", "", None

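# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app) of the chunking arithmetic used
# in analyze_audio(): ceil-divide the waveform into fixed 3-second windows and
# work out how much zero-padding the final, partial window needs.
# ---------------------------------------------------------------------------
def _demo_chunk_plan(n_samples, sr=16000, chunk_seconds=3.0):
    chunk_samples = int(chunk_seconds * sr)
    total_chunks = (n_samples + chunk_samples - 1) // chunk_samples  # ceil division
    last_len = n_samples - (total_chunks - 1) * chunk_samples if total_chunks else 0
    padding = (chunk_samples - last_len) % chunk_samples
    return total_chunks, padding
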
print("Building UI...")

with gr.Blocks(title="Stutter Analysis", css="""
.loading-text {
    font-size: 1.2em;
    color: #666;
    padding: 20px;
    text-align: center;
}
""") as demo:
    gr.Markdown("""
    # 🎙️ Speech Fluency Analysis System

    Upload an audio file to analyze stuttering patterns using AI (WavLM + Whisper).

    **Supported formats:** WAV, MP3, M4A, FLAC, OGG
    """)

    # Store analysis data for MedGemma
    analysis_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            audio = gr.Audio(label="🎤 Upload Audio", type="filepath")
            threshold = gr.Slider(
                minimum=0.3,
                maximum=0.7,
                value=0.5,
                step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more strict"
            )
            btn = gr.Button("🔍 Analyze Speech", variant="primary", size="lg")
            gr.Markdown("*Analysis takes 30-60 seconds depending on audio length*")

        with gr.Column(scale=2):
            summary = gr.Markdown(value="### 👆 Upload audio and click Analyze to start")

    with gr.Tabs():
        with gr.TabItem("📝 Transcription"):
            trans = gr.Markdown()
        with gr.TabItem("📈 Timeline"):
            timeline = gr.Markdown()
        with gr.TabItem("📖 Definitions"):
            defs = gr.Markdown()
        with gr.TabItem("🏥 Clinical Report (MedGemma)"):
            gr.Markdown("### Automatic Clinical Narrative Generation")
            gr.Markdown("*Powered by Google MedGemma (HAI-DEF)*")
            gen_btn = gr.Button("✨ Generate Professional Report", variant="secondary")
            report_out = gr.Markdown("⚠️ Please run analysis first to generate report data.")

    gr.Markdown("""
    ---
    **Note:** The spinner will appear while processing. Please wait for analysis to complete.
    """)

    # show_progress="full" displays a loading spinner while analyze_audio runs
    btn.click(
        fn=analyze_audio,
        inputs=[audio, threshold],
        outputs=[summary, trans, timeline, defs, analysis_state],
        show_progress="full"
    )

    gen_btn.click(
        fn=generate_medgemma_report,
        inputs=[analysis_state],
        outputs=[report_out]
    )

print("Loading models...")
load_models()

print("Launching...")
demo.queue()
demo.launch(ssr_mode=False)
packages.txt ADDED
@@ -0,0 +1,2 @@
ffmpeg
libsndfile1
requirements.txt ADDED
@@ -0,0 +1,10 @@
torch>=2.0.0
torchaudio>=2.0.0
transformers>=4.50.0
gradio>=4.0.0
openai-whisper>=20231117
numpy>=1.24.0
soundfile>=0.12.0
librosa>=0.10.0
accelerate>=0.26.0
bitsandbytes>=0.41.0
wavlm_stutter_classification_best.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b98f4e50fa40a0cd43602858d77bff78692a68b14fb6bb7144b5d2a12155071b
size 377646731