throgletworld commited on
Commit
ba86bda
·
verified ·
1 Parent(s): 0f70449

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +227 -279
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,320 +1,268 @@
1
- import gradio as gr
 
 
 
 
 
2
  import torch
 
 
3
  import numpy as np
4
- import os
5
- import traceback
6
  from datetime import datetime
7
  from transformers import WavLMModel
8
- import torch.nn as nn
9
 
10
- print(f"APP STARTUP: {datetime.now()}")
 
 
 
 
 
 
 
 
 
 
11
 
12
- # =============================================================================
13
- # WHY SIGMOID INSTEAD OF SOFTMAX? - A DETAILED EXPLANATION
14
- # =============================================================================
15
- """
16
- MULTI-LABEL vs MULTI-CLASS CLASSIFICATION
17
- ==========================================
18
-
19
- Our stutter detection is a MULTI-LABEL problem:
20
- - A single 3-second audio chunk can have MULTIPLE stutters simultaneously
21
- - Example: Someone might have a "Block" AND a "SoundRep" in the same chunk
22
- - Each of the 5 stutter types is INDEPENDENT of the others
23
-
24
- SOFTMAX (❌ NOT suitable for us):
25
- ---------------------------------
26
- - Used for MULTI-CLASS problems where classes are MUTUALLY EXCLUSIVE
27
- - Example: "Is this image a Cat OR a Dog?" (can't be both)
28
- - Formula: softmax(x_i) = exp(x_i) / sum(exp(x_j)) for all j
29
- - All probabilities MUST sum to 1.0
30
- - Problem: If we used softmax and got [0.7, 0.1, 0.1, 0.05, 0.05]:
31
- - It would say "70% Prolongation" but FORCE other classes to be low
32
- - We couldn't detect multiple stutters in one chunk!
33
-
34
- SIGMOID (✅ CORRECT for us):
35
- ----------------------------
36
- - Used for MULTI-LABEL problems where classes are INDEPENDENT
37
- - Each class gets its own independent probability (0 to 1)
38
- - Formula: sigmoid(x) = 1 / (1 + exp(-x))
39
- - Probabilities DON'T need to sum to 1
40
- - Example output: [0.8, 0.7, 0.2, 0.1, 0.05]
41
- - 80% chance of Prolongation
42
- - 70% chance of Block
43
- - Both can be detected simultaneously!
44
-
45
- THE TRAINING & INFERENCE FLOW:
46
- ==============================
47
-
48
- TRAINING:
49
- ---------
50
- 1. Model outputs: LOGITS (raw scores from -∞ to +∞)
51
- Example: [2.5, -3.0, 0.1, -1.5, -2.0]
52
-
53
- 2. Loss Function: BCEWithLogitsLoss
54
- - "WithLogits" means it applies Sigmoid INTERNALLY
55
- - More numerically stable than separate Sigmoid + BCELoss
56
- - Compares each prediction to each ground truth label independently
57
-
58
- INFERENCE (this file):
59
- ----------------------
60
- 1. Model outputs: LOGITS (same as training)
61
- Example: [2.5, -3.0, 0.1, -1.5, -2.0]
62
-
63
- 2. We manually apply Sigmoid to convert to probabilities:
64
- probs = torch.sigmoid(logits)
65
- Result: [0.92, 0.05, 0.52, 0.18, 0.12]
66
-
67
- 3. Apply threshold (e.g., 0.5) to each probability:
68
- - 0.92 > 0.5 → Prolongation DETECTED
69
- - 0.05 < 0.5 → Block NOT detected
70
- - 0.52 > 0.5 → SoundRep DETECTED
71
- - etc.
72
-
73
- 4. If NO stutters detected (all below threshold):
74
- → Label the chunk as "Fluent"
75
-
76
- THRESHOLD EXPLAINED:
77
- ====================
78
- - Default: 0.5 (theoretically neutral, since sigmoid(0) = 0.5)
79
- - Lower threshold (0.3-0.4): More SENSITIVE, catches more stutters, but more false positives
80
- - Higher threshold (0.6-0.7): More STRICT, fewer false positives, but might miss subtle stutters
81
- - The slider in the UI lets users adjust this based on their needs
82
- - SAME threshold is applied to ALL 5 classes (simplest approach)
83
- """
84
 
85
  class WaveLmStutterClassification(nn.Module):
86
  def __init__(self, num_labels=5):
87
  super().__init__()
88
  self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
89
  self.hidden_size = self.wavlm.config.hidden_size
90
- for param in self.wavlm.parameters():
91
- param.requires_grad = False
92
  self.classifier = nn.Linear(self.hidden_size, num_labels)
93
- self.num_labels = num_labels
94
-
95
- def forward(self, input_values, attention_mask=None):
96
- outputs = self.wavlm(input_values, attention_mask=attention_mask)
97
- hidden_states = outputs.last_hidden_state
98
- pooled = hidden_states.mean(dim=1)
99
- logits = self.classifier(pooled)
100
- return logits
101
-
102
- STUTTER_LABELS = ['Prolongation', 'Block', 'SoundRep', 'WordRep', 'Interjection']
103
-
104
- STUTTER_DEFINITIONS = {
105
- 'Prolongation': 'Sound stretched longer than normal',
106
- 'Block': 'Complete stoppage of airflow/sound',
107
- 'SoundRep': 'Sound/syllable repetition',
108
- 'WordRep': 'Whole word repetition',
109
- 'Interjection': 'Filler words like um, uh'
110
- }
111
 
112
- device = "cuda" if torch.cuda.is_available() else "cpu"
113
- print(f"Device: {device}")
 
 
114
 
115
  wavlm_model = None
116
  whisper_model = None
117
  models_loaded = False
118
 
 
119
  def load_models():
 
120
  global wavlm_model, whisper_model, models_loaded
121
  if models_loaded:
122
  return True
123
- try:
124
- print("Loading WavLM...")
125
- wavlm_model = WaveLmStutterClassification(num_labels=5)
126
- checkpoint_path = "wavlm_stutter_classification_best.pth"
127
- if os.path.exists(checkpoint_path):
128
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
129
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
130
- wavlm_model.load_state_dict(checkpoint['model_state_dict'])
131
- else:
132
- wavlm_model.load_state_dict(checkpoint)
133
- print("Checkpoint loaded!")
134
- wavlm_model.to(device)
135
- wavlm_model.eval()
136
-
137
- print("Loading Whisper...")
138
- import whisper
139
- whisper_model = whisper.load_model("base", device=device)
140
-
141
- models_loaded = True
142
- print("Models loaded!")
143
- return True
144
- except Exception as e:
145
- print(f"Model loading error: {e}")
146
- traceback.print_exc()
147
- return False
148
-
149
- def load_audio(audio_path):
150
- print(f"Loading: {audio_path}")
151
- try:
152
- import librosa
153
- waveform, sr = librosa.load(audio_path, sr=16000, mono=True)
154
- return torch.from_numpy(waveform).float(), 16000
155
- except Exception as e:
156
- print(f"librosa error: {e}")
157
- try:
158
- import soundfile as sf
159
- waveform, sr = sf.read(audio_path, dtype='float32')
160
- if len(waveform.shape) > 1:
161
- waveform = waveform.mean(axis=1)
162
- waveform = torch.from_numpy(waveform).float()
163
- if sr != 16000:
164
- import torchaudio
165
- waveform = torchaudio.transforms.Resample(sr, 16000)(waveform.unsqueeze(0)).squeeze(0)
166
- return waveform, 16000
167
- except Exception as e:
168
- print(f"soundfile error: {e}")
169
- raise Exception("Could not load audio")
170
-
171
- def analyze_chunk(chunk_tensor, threshold=0.5):
172
  with torch.no_grad():
173
- logits = wavlm_model(chunk_tensor.unsqueeze(0).to(device))
174
  probs = torch.sigmoid(logits).cpu().numpy()[0]
175
  detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold]
176
- return detected, dict(zip(STUTTER_LABELS, probs.tolist()))
177
-
178
- def analyze_audio(audio_input, threshold, progress=gr.Progress()):
179
- print(f"\n=== ANALYZE CLICKED ===")
180
- print(f"Input: {audio_input}, Type: {type(audio_input)}, Threshold: {threshold}")
181
-
182
- progress(0, desc="🔄 Starting analysis...")
183
-
184
- if audio_input is None:
185
- return "⚠️ Please upload an audio file first!", "", "", ""
186
-
187
- audio_path = audio_input
188
- if isinstance(audio_input, tuple):
189
  import tempfile, soundfile as sf
190
- sr, data = audio_input
191
- f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
192
- sf.write(f.name, data, sr)
193
- audio_path = f.name
194
-
195
- if not os.path.exists(audio_path):
196
- return f"File not found: {audio_path}", "", "", ""
197
-
198
- print(f"File: {audio_path}, Size: {os.path.getsize(audio_path)}")
199
-
200
- try:
201
- progress(0.1, desc="🔄 Loading models...")
202
- if not models_loaded and not load_models():
203
- return "❌ Failed to load models", "", "", ""
204
-
205
- progress(0.2, desc="🎵 Loading audio file...")
206
- waveform, sr = load_audio(audio_path)
207
- duration = len(waveform) / sr
208
- print(f"Duration: {duration:.1f}s")
209
-
210
- progress(0.3, desc="✂️ Splitting audio into chunks...")
211
- chunk_samples = int(3.0 * sr)
212
- stutter_counts = {l: 0 for l in STUTTER_LABELS}
213
- timeline = []
214
-
215
- total_chunks = (len(waveform) + chunk_samples - 1) // chunk_samples
216
-
217
- for i, start in enumerate(range(0, len(waveform), chunk_samples)):
218
- progress(0.3 + (0.4 * i / total_chunks), desc=f"🔍 Analyzing chunk {i+1}/{total_chunks}...")
219
-
220
- end = min(start + chunk_samples, len(waveform))
221
- chunk = waveform[start:end]
222
- if len(chunk) < chunk_samples:
223
- chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
224
-
225
- detected, _ = analyze_chunk(chunk, threshold)
226
- for l in detected:
227
- stutter_counts[l] += 1
228
- timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["Fluent"]})
229
-
230
- progress(0.75, desc="🗣️ Transcribing with Whisper...")
231
- print("Running Whisper...")
232
- transcription = whisper_model.transcribe(audio_path).get('text', '')
233
-
234
- progress(0.9, desc="📊 Generating report...")
235
- total = sum(stutter_counts.values())
236
- summary = f"## Analysis Complete!\n\n**Duration:** {duration:.1f}s\n**Total Stutters Detected:** {total}\n\n### Stutter Counts:\n"
237
- for l, c in stutter_counts.items():
238
- emoji = "🔴" if c > 0 else "⚪"
239
- summary += f"- {emoji} **{l}**: {c}\n"
240
-
241
- timeline_md = "| Time | Detected |\n|---|---|\n"
242
- for t in timeline[:15]:
243
- timeline_md += f"| {t['time']} | {', '.join(t['detected'])} |\n"
244
- if len(timeline) > 15:
245
- timeline_md += f"\n*...and {len(timeline) - 15} more chunks*"
246
-
247
- defs = "## 📖 Stutter Type Definitions\n\n"
248
- defs += "\n".join([f"**{k}:** {v}" for k, v in STUTTER_DEFINITIONS.items()])
249
-
250
- progress(1.0, desc=" Done!")
251
- print("Done!")
252
- return summary, transcription, timeline_md, defs
253
-
254
- except Exception as e:
255
- print(f"Error: {e}")
256
- traceback.print_exc()
257
- return f"Error: {e}\n\n{traceback.format_exc()}", "", "", ""
258
-
259
- print("Building UI...")
260
-
261
- with gr.Blocks(title="Stutter Analysis", css="""
262
- .loading-text {
263
- font-size: 1.2em;
264
- color: #666;
265
- padding: 20px;
266
- text-align: center;
267
- }
268
- """) as demo:
269
- gr.Markdown("""
270
- # 🎙️ Speech Fluency Analysis System
271
-
272
- Upload an audio file to analyze stuttering patterns using AI (WavLM + Whisper).
273
-
274
- **Supported formats:** WAV, MP3, M4A, FLAC, OGG
275
- """)
276
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  with gr.Row():
278
  with gr.Column(scale=1):
279
- audio = gr.Audio(label="🎤 Upload Audio", type="filepath")
280
  threshold = gr.Slider(
281
- minimum=0.3,
282
- maximum=0.7,
283
- value=0.5,
284
- step=0.05,
285
  label="Detection Threshold",
286
- info="Lower = more sensitive, Higher = more strict"
287
  )
288
- btn = gr.Button("🔍 Analyze Speech", variant="primary", size="lg")
289
- gr.Markdown("*Analysis takes 30-60 seconds depending on audio length*")
290
-
291
  with gr.Column(scale=2):
292
- summary = gr.Markdown(value="### 👆 Upload audio and click Analyze to start")
293
-
294
  with gr.Tabs():
295
- with gr.TabItem("📝 Transcription"):
296
- trans = gr.Markdown()
297
- with gr.TabItem("📈 Timeline"):
298
- timeline = gr.Markdown()
299
- with gr.TabItem("📖 Definitions"):
300
- defs = gr.Markdown()
301
-
302
- gr.Markdown("""
303
- ---
304
- **Note:** The spinner will appear while processing. Please wait for analysis to complete.
305
- """)
306
-
307
- # The show_progress parameter shows a spinner during processing
308
  btn.click(
309
- fn=analyze_audio,
310
- inputs=[audio, threshold],
311
- outputs=[summary, trans, timeline, defs],
312
- show_progress="full" # Shows loading spinner
313
  )
314
 
315
- print("Loading models...")
316
  load_models()
317
 
318
- print("Launching...")
319
  demo.queue()
320
  demo.launch(ssr_mode=False)
 
1
+ """
2
+ Speech Fluency Analysis - Hugging Face Gradio App
3
+ WavLM stutter detection + Whisper transcription.
4
+ """
5
+
6
+ import os
7
  import torch
8
+ import torch.nn as nn
9
+ import torchaudio
10
  import numpy as np
11
+ import gradio as gr
 
12
  from datetime import datetime
13
  from transformers import WavLMModel
 
14
 
15
+ STUTTER_LABELS = ["Prolongation", "Block", "SoundRep", "WordRep", "Interjection"]
16
+
17
+ STUTTER_INFO = {
18
+ "Prolongation": "Sound stretched longer than normal (e.g. 'Ssssnake')",
19
+ "Block": "Complete stoppage of airflow/sound with tension",
20
+ "SoundRep": "Sound/syllable repetition (e.g. 'B-b-b-ball')",
21
+ "WordRep": "Whole word repetition (e.g. 'I-I-I want')",
22
+ "Interjection": "Filler words like 'um', 'uh', 'like'",
23
+ }
24
+
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  class WaveLmStutterClassification(nn.Module):
29
  def __init__(self, num_labels=5):
30
  super().__init__()
31
  self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
32
  self.hidden_size = self.wavlm.config.hidden_size
33
+ for p in self.wavlm.parameters():
34
+ p.requires_grad = False
35
  self.classifier = nn.Linear(self.hidden_size, num_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def forward(self, x, attention_mask=None):
38
+ h = self.wavlm(x, attention_mask=attention_mask).last_hidden_state
39
+ return self.classifier(h.mean(dim=1))
40
+
41
 
42
  wavlm_model = None
43
  whisper_model = None
44
  models_loaded = False
45
 
46
+
47
  def load_models():
48
+ """Load WavLM checkpoint and Whisper once."""
49
  global wavlm_model, whisper_model, models_loaded
50
  if models_loaded:
51
  return True
52
+
53
+ print("Loading WavLM ...")
54
+ wavlm_model = WaveLmStutterClassification(num_labels=5)
55
+ ckpt = "wavlm_stutter_classification_best.pth"
56
+ if os.path.exists(ckpt):
57
+ state = torch.load(ckpt, map_location=DEVICE, weights_only=False)
58
+ if isinstance(state, dict) and "model_state_dict" in state:
59
+ wavlm_model.load_state_dict(state["model_state_dict"])
60
+ else:
61
+ wavlm_model.load_state_dict(state)
62
+ wavlm_model.to(DEVICE).eval()
63
+
64
+ print("Loading Whisper ...")
65
+ import whisper
66
+ whisper_model = whisper.load_model("base", device=DEVICE)
67
+
68
+ models_loaded = True
69
+ print("Models ready.")
70
+ return True
71
+
72
+
73
+ # FFmpeg explained:
74
+ # torchaudio.load() uses FFmpeg under the hood as a system-level library to
75
+ # DECODE compressed audio formats (mp3, m4a, ogg, flac) into raw PCM samples.
76
+ # FFmpeg is a CLI/OS tool - torchaudio calls it via its C backend.
77
+ # The decoded PCM data is then wrapped into a torch.Tensor (the waveform).
78
+ #
79
+ # Pipeline: audio file -> FFmpeg decodes -> raw samples -> torch.Tensor
80
+ #
81
+ # packages.txt lists "ffmpeg" so HF Spaces installs it at OS level.
82
+
83
+ def load_audio(path):
84
+ """Load any audio file to 16 kHz mono tensor via torchaudio (uses FFmpeg)."""
85
+ waveform, sr = torchaudio.load(path)
86
+ if waveform.size(0) > 1:
87
+ waveform = waveform.mean(dim=0, keepdim=True)
88
+ if sr != 16000:
89
+ waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
90
+ return waveform.squeeze(0), 16000
91
+
92
+
93
+ def analyze_chunk(chunk, threshold=0.5):
94
+ """Run WavLM on a single chunk."""
 
 
 
 
 
 
95
  with torch.no_grad():
96
+ logits = wavlm_model(chunk.unsqueeze(0).to(DEVICE))
97
  probs = torch.sigmoid(logits).cpu().numpy()[0]
98
  detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold]
99
+ prob_dict = dict(zip(STUTTER_LABELS, [round(float(p), 3) for p in probs]))
100
+ return detected, prob_dict
101
+
102
+
103
+ def analyze_audio(audio_path, threshold, progress=gr.Progress()):
104
+ """Main pipeline: chunk -> WavLM -> Whisper -> formatted results."""
105
+ if audio_path is None:
106
+ return "Upload an audio file first.", "", "", ""
107
+
108
+ if isinstance(audio_path, tuple):
 
 
 
109
  import tempfile, soundfile as sf
110
+ sr, data = audio_path
111
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
112
+ sf.write(tmp.name, data, sr)
113
+ audio_path = tmp.name
114
+
115
+ progress(0.05, desc="Loading models ...")
116
+ if not models_loaded and not load_models():
117
+ return "Failed to load models.", "", "", ""
118
+
119
+ progress(0.15, desc="Loading audio ...")
120
+ waveform, sr = load_audio(audio_path)
121
+ duration = len(waveform) / sr
122
+
123
+ progress(0.25, desc="Detecting stutters ...")
124
+ chunk_samples = 3 * sr
125
+ counts = {l: 0 for l in STUTTER_LABELS}
126
+ timeline_rows = []
127
+ total_chunks = max(1, (len(waveform) + chunk_samples - 1) // chunk_samples)
128
+
129
+ for i, start in enumerate(range(0, len(waveform), chunk_samples)):
130
+ progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...")
131
+ end = min(start + chunk_samples, len(waveform))
132
+ chunk = waveform[start:end]
133
+ if len(chunk) < chunk_samples:
134
+ chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
135
+
136
+ detected, probs = analyze_chunk(chunk, threshold)
137
+ for label in detected:
138
+ counts[label] += 1
139
+
140
+ time_str = f"{start/sr:.1f}-{end/sr:.1f}s"
141
+ timeline_rows.append({"time": time_str, "detected": detected or ["Fluent"], "probs": probs})
142
+
143
+ progress(0.75, desc="Transcribing ...")
144
+ transcription = whisper_model.transcribe(audio_path).get("text", "").strip()
145
+
146
+ progress(0.90, desc="Building report ...")
147
+ total_stutters = sum(counts.values())
148
+ chunks_with_stutter = sum(1 for r in timeline_rows if "Fluent" not in r["detected"])
149
+ stutter_pct = (chunks_with_stutter / total_chunks) * 100 if total_chunks else 0
150
+ word_count = len(transcription.split()) if transcription else 0
151
+ wpm = (word_count / duration) * 60 if duration > 0 else 0
152
+
153
+ severity = (
154
+ "Very Mild" if stutter_pct < 5 else
155
+ "Mild" if stutter_pct < 10 else
156
+ "Moderate" if stutter_pct < 20 else
157
+ "Severe" if stutter_pct < 30 else
158
+ "Very Severe"
159
+ )
160
+
161
+ summary_lines = [
162
+ "## Analysis Results\n",
163
+ "| Metric | Value |",
164
+ "|--------|-------|",
165
+ f"| Duration | {duration:.1f}s |",
166
+ f"| Words | {word_count} |",
167
+ f"| Speaking Rate | {wpm:.0f} wpm |",
168
+ f"| Stutter Events | {total_stutters} |",
169
+ f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |",
170
+ f"| Severity | **{severity}** |",
171
+ "",
172
+ "### Stutter Counts",
173
+ "",
174
+ ]
175
+ for label in STUTTER_LABELS:
176
+ c = counts[label]
177
+ bar = "X" * min(c, 20)
178
+ icon = "!" if c > 0 else "o"
179
+ summary_lines.append(f"- {icon} **{label}**: {c} {bar}")
180
+
181
+ summary_md = "\n".join(summary_lines)
182
+
183
+ tl_lines = ["| Time | Detected |", "|------|----------|"]
184
+ for row in timeline_rows:
185
+ tl_lines.append(f"| {row['time']} | {', '.join(row['detected'])} |")
186
+ timeline_md = "\n".join(tl_lines)
187
+
188
+ recs = ["## Recommendations\n"]
189
+ if severity in ("Very Mild", "Mild"):
190
+ recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.")
191
+ elif severity == "Moderate":
192
+ recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.")
193
+ else:
194
+ recs.append("- Professional speech-language pathology evaluation is strongly recommended.")
195
+
196
+ dominant = max(counts, key=counts.get)
197
+ if counts[dominant] > 0:
198
+ recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}")
199
+
200
+ if wpm > 180:
201
+ recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.")
202
+
203
+ recs.append("\n### Stutter Type Definitions\n")
204
+ for label, desc in STUTTER_INFO.items():
205
+ recs.append(f"- **{label}**: {desc}")
206
+
207
+ recs_md = "\n".join(recs)
208
+
209
+ progress(1.0, desc="Done!")
210
+ return summary_md, transcription, timeline_md, recs_md
211
+
212
+
213
+ CUSTOM_CSS = """
214
+ .gradio-container { max-width: 960px !important; }
215
+ .gr-button-primary { background: #0f766e !important; }
216
+ """
217
+
218
+ with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
219
+
220
+ gr.Markdown(
221
+ """
222
+ # Speech Fluency Analysis
223
+ Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection)
224
+ and **Whisper** (transcription).
225
+
226
+ Supported formats: **WAV, MP3, M4A, FLAC, OGG**
227
+ """
228
+ )
229
+
230
  with gr.Row():
231
  with gr.Column(scale=1):
232
+ audio_in = gr.Audio(label="Upload Audio", type="filepath")
233
  threshold = gr.Slider(
234
+ 0.3, 0.7, value=0.5, step=0.05,
 
 
 
235
  label="Detection Threshold",
236
+ info="Lower = more sensitive, Higher = more strict",
237
  )
238
+ btn = gr.Button("Analyze", variant="primary", size="lg")
239
+
 
240
  with gr.Column(scale=2):
241
+ summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*")
242
+
243
  with gr.Tabs():
244
+ with gr.TabItem("Transcription"):
245
+ trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False)
246
+ with gr.TabItem("Timeline"):
247
+ timeline_out = gr.Markdown()
248
+ with gr.TabItem("Recommendations"):
249
+ recs_out = gr.Markdown()
250
+
251
+ gr.Markdown(
252
+ "---\n*Disclaimer: AI-assisted analysis for clinical support only. "
253
+ "Consult a qualified Speech-Language Pathologist for diagnosis.*"
254
+ )
255
+
 
256
  btn.click(
257
+ fn=analyze_audio,
258
+ inputs=[audio_in, threshold],
259
+ outputs=[summary_out, trans_out, timeline_out, recs_out],
260
+ show_progress="full",
261
  )
262
 
263
+ print("Loading models at startup ...")
264
  load_models()
265
 
266
+ print("Launching Gradio ...")
267
  demo.queue()
268
  demo.launch(ssr_mode=False)
requirements.txt CHANGED
@@ -5,4 +5,4 @@ gradio>=4.0.0
5
  openai-whisper>=20231117
6
  numpy>=1.24.0
7
  soundfile>=0.12.0
8
- librosa>=0.10.0
 
5
  openai-whisper>=20231117
6
  numpy>=1.24.0
7
  soundfile>=0.12.0
8
+ huggingface_hub>=0.19.0