Upload 3 files
app.py  CHANGED
@@ -9,6 +9,79 @@ import torch.nn as nn
print(f"APP STARTUP: {datetime.now()}")

+# =============================================================================
+# WHY SIGMOID INSTEAD OF SOFTMAX? - A DETAILED EXPLANATION
+# =============================================================================
+"""
+MULTI-LABEL vs MULTI-CLASS CLASSIFICATION
+==========================================
+
+Our stutter detection is a MULTI-LABEL problem:
+- A single 3-second audio chunk can have MULTIPLE stutters simultaneously
+- Example: Someone might have a "Block" AND a "SoundRep" in the same chunk
+- Each of the 5 stutter types is INDEPENDENT of the others
+
+SOFTMAX (❌ NOT suitable for us):
+---------------------------------
+- Used for MULTI-CLASS problems where classes are MUTUALLY EXCLUSIVE
+- Example: "Is this image a Cat OR a Dog?" (can't be both)
+- Formula: softmax(x_i) = exp(x_i) / sum(exp(x_j)) for all j
+- All probabilities MUST sum to 1.0
+- Problem: If we used softmax and got [0.7, 0.1, 0.1, 0.05, 0.05]:
+  - It would say "70% Prolongation" but FORCE other classes to be low
+  - We couldn't detect multiple stutters in one chunk!
+
+SIGMOID (✅ CORRECT for us):
+----------------------------
+- Used for MULTI-LABEL problems where classes are INDEPENDENT
+- Each class gets its own independent probability (0 to 1)
+- Formula: sigmoid(x) = 1 / (1 + exp(-x))
+- Probabilities DON'T need to sum to 1
+- Example output: [0.8, 0.7, 0.2, 0.1, 0.05]
+  - 80% chance of Prolongation
+  - 70% chance of Block
+  - Both can be detected simultaneously!
+
+THE TRAINING & INFERENCE FLOW:
+==============================
+
+TRAINING:
+---------
+1. Model outputs: LOGITS (raw scores from -∞ to +∞)
+   Example: [2.5, -3.0, 0.1, -1.5, -2.0]
+
+2. Loss Function: BCEWithLogitsLoss
+   - "WithLogits" means it applies Sigmoid INTERNALLY
+   - More numerically stable than separate Sigmoid + BCELoss
+   - Compares each prediction to each ground truth label independently
+
+INFERENCE (this file):
+----------------------
+1. Model outputs: LOGITS (same as training)
+   Example: [2.5, -3.0, 0.1, -1.5, -2.0]
+
+2. We manually apply Sigmoid to convert to probabilities:
+   probs = torch.sigmoid(logits)
+   Result: [0.92, 0.05, 0.52, 0.18, 0.12]
+
+3. Apply threshold (e.g., 0.5) to each probability:
+   - 0.92 > 0.5 → Prolongation DETECTED
+   - 0.05 < 0.5 → Block NOT detected
+   - 0.52 > 0.5 → SoundRep DETECTED
+   - etc.
+
+4. If NO stutters detected (all below threshold):
+   → Label the chunk as "Fluent"
+
+THRESHOLD EXPLAINED:
+====================
+- Default: 0.5 (theoretically neutral, since sigmoid(0) = 0.5)
+- Lower threshold (0.3-0.4): More SENSITIVE, catches more stutters, but more false positives
+- Higher threshold (0.6-0.7): More STRICT, fewer false positives, but might miss subtle stutters
+- The slider in the UI lets users adjust this based on their needs
+- SAME threshold is applied to ALL 5 classes (simplest approach)
+"""
+
class WaveLmStutterClassification(nn.Module):
    def __init__(self, num_labels=5):
        super().__init__()
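To make the softmax-vs-sigmoid contrast in the docstring concrete, here is a minimal sketch (not part of the commit) that runs both activations over the example logits quoted above; printed values are approximate.

import torch

# Example logits from the docstring: one 3-second chunk, 5 stutter classes.
logits = torch.tensor([2.5, -3.0, 0.1, -1.5, -2.0])

softmax_probs = torch.softmax(logits, dim=0)   # forced to sum to 1.0
sigmoid_probs = torch.sigmoid(logits)          # each class scored independently

print(softmax_probs.tolist())        # one dominant class, the rest suppressed
print(softmax_probs.sum().item())    # 1.0
print(sigmoid_probs.tolist())        # ~[0.92, 0.05, 0.52, 0.18, 0.12]
print(sigmoid_probs.sum().item())    # does not need to sum to 1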
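And a small sketch of the training-vs-inference asymmetry described above: BCEWithLogitsLoss consumes raw logits during training, while inference applies torch.sigmoid explicitly and thresholds each class independently. The multi-hot target below is illustrative only, not taken from the dataset.

import torch
import torch.nn as nn

logits = torch.tensor([[2.5, -3.0, 0.1, -1.5, -2.0]])   # (batch=1, 5 classes)
targets = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]])     # illustrative multi-hot labels

# TRAINING: BCEWithLogitsLoss applies sigmoid internally (numerically stable).
loss = nn.BCEWithLogitsLoss()(logits, targets)

# INFERENCE: apply sigmoid manually, then threshold each class on its own.
probs = torch.sigmoid(logits)[0]       # ~[0.92, 0.05, 0.52, 0.18, 0.12]
detected_mask = probs > 0.5            # per-class decision, not argmax
print(loss.item(), detected_mask.tolist())   # several classes can be True at once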
@@ -152,7 +225,7 @@ def analyze_audio(audio_input, threshold, progress=gr.Progress()):
        detected, _ = analyze_chunk(chunk, threshold)
        for l in detected:
            stutter_counts[l] += 1
-        timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["
+        timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["Fluent"]})

    progress(0.75, desc="🗣️ Transcribing with Whisper...")
    print("Running Whisper...")
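To illustrate the threshold slider and the ["Fluent"] fallback added in this hunk, a minimal decoding sketch under assumed names: only Prolongation, Block, and SoundRep are named in the docstring, so the last two labels below are placeholders.

import torch

# Placeholder label order; only the first three names appear in the docstring above.
LABELS = ["Prolongation", "Block", "SoundRep", "Label3", "Label4"]

def decode_chunk(probs, threshold):
    # Keep every class whose independent probability clears the threshold.
    detected = [LABELS[i] for i, p in enumerate(probs) if p > threshold]
    # If nothing clears the threshold, treat the chunk as fluent speech.
    return detected or ["Fluent"]

probs = torch.sigmoid(torch.tensor([2.5, -3.0, 0.1, -1.5, -2.0])).tolist()
print(decode_chunk(probs, 0.5))    # ['Prolongation', 'SoundRep']
print(decode_chunk(probs, 0.7))    # stricter: ['Prolongation']
print(decode_chunk(probs, 0.95))   # nothing detected: ['Fluent']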