throgletworld commited on
Commit
0f70449
Β·
verified Β·
1 Parent(s): 0b7e787

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +74 -1
app.py CHANGED
@@ -9,6 +9,79 @@ import torch.nn as nn
9
 
10
  print(f"APP STARTUP: {datetime.now()}")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class WaveLmStutterClassification(nn.Module):
13
  def __init__(self, num_labels=5):
14
  super().__init__()
@@ -152,7 +225,7 @@ def analyze_audio(audio_input, threshold, progress=gr.Progress()):
152
  detected, _ = analyze_chunk(chunk, threshold)
153
  for l in detected:
154
  stutter_counts[l] += 1
155
- timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["Clear"]})
156
 
157
  progress(0.75, desc="πŸ—£οΈ Transcribing with Whisper...")
158
  print("Running Whisper...")
 
9
 
10
  print(f"APP STARTUP: {datetime.now()}")
11
 
12
+ # =============================================================================
13
+ # WHY SIGMOID INSTEAD OF SOFTMAX? - A DETAILED EXPLANATION
14
+ # =============================================================================
15
+ """
16
+ MULTI-LABEL vs MULTI-CLASS CLASSIFICATION
17
+ ==========================================
18
+
19
+ Our stutter detection is a MULTI-LABEL problem:
20
+ - A single 3-second audio chunk can have MULTIPLE stutters simultaneously
21
+ - Example: Someone might have a "Block" AND a "SoundRep" in the same chunk
22
+ - Each of the 5 stutter types is INDEPENDENT of the others
23
+
24
+ SOFTMAX (❌ NOT suitable for us):
25
+ ---------------------------------
26
+ - Used for MULTI-CLASS problems where classes are MUTUALLY EXCLUSIVE
27
+ - Example: "Is this image a Cat OR a Dog?" (can't be both)
28
+ - Formula: softmax(x_i) = exp(x_i) / sum(exp(x_j)) for all j
29
+ - All probabilities MUST sum to 1.0
30
+ - Problem: If we used softmax and got [0.7, 0.1, 0.1, 0.05, 0.05]:
31
+ - It would say "70% Prolongation" but FORCE other classes to be low
32
+ - We couldn't detect multiple stutters in one chunk!
33
+
34
+ SIGMOID (βœ… CORRECT for us):
35
+ ----------------------------
36
+ - Used for MULTI-LABEL problems where classes are INDEPENDENT
37
+ - Each class gets its own independent probability (0 to 1)
38
+ - Formula: sigmoid(x) = 1 / (1 + exp(-x))
39
+ - Probabilities DON'T need to sum to 1
40
+ - Example output: [0.8, 0.7, 0.2, 0.1, 0.05]
41
+ - 80% chance of Prolongation
42
+ - 70% chance of Block
43
+ - Both can be detected simultaneously!
44
+
45
+ THE TRAINING & INFERENCE FLOW:
46
+ ==============================
47
+
48
+ TRAINING:
49
+ ---------
50
+ 1. Model outputs: LOGITS (raw scores from -∞ to +∞)
51
+ Example: [2.5, -3.0, 0.1, -1.5, -2.0]
52
+
53
+ 2. Loss Function: BCEWithLogitsLoss
54
+ - "WithLogits" means it applies Sigmoid INTERNALLY
55
+ - More numerically stable than separate Sigmoid + BCELoss
56
+ - Compares each prediction to each ground truth label independently
57
+
58
+ INFERENCE (this file):
59
+ ----------------------
60
+ 1. Model outputs: LOGITS (same as training)
61
+ Example: [2.5, -3.0, 0.1, -1.5, -2.0]
62
+
63
+ 2. We manually apply Sigmoid to convert to probabilities:
64
+ probs = torch.sigmoid(logits)
65
+ Result: [0.92, 0.05, 0.52, 0.18, 0.12]
66
+
67
+ 3. Apply threshold (e.g., 0.5) to each probability:
68
+ - 0.92 > 0.5 β†’ Prolongation DETECTED
69
+ - 0.05 < 0.5 β†’ Block NOT detected
70
+ - 0.52 > 0.5 β†’ SoundRep DETECTED
71
+ - etc.
72
+
73
+ 4. If NO stutters detected (all below threshold):
74
+ β†’ Label the chunk as "Fluent"
75
+
76
+ THRESHOLD EXPLAINED:
77
+ ====================
78
+ - Default: 0.5 (theoretically neutral, since sigmoid(0) = 0.5)
79
+ - Lower threshold (0.3-0.4): More SENSITIVE, catches more stutters, but more false positives
80
+ - Higher threshold (0.6-0.7): More STRICT, fewer false positives, but might miss subtle stutters
81
+ - The slider in the UI lets users adjust this based on their needs
82
+ - SAME threshold is applied to ALL 5 classes (simplest approach)
83
+ """
84
+
85
  class WaveLmStutterClassification(nn.Module):
86
  def __init__(self, num_labels=5):
87
  super().__init__()
 
225
  detected, _ = analyze_chunk(chunk, threshold)
226
  for l in detected:
227
  stutter_counts[l] += 1
228
+ timeline.append({"time": f"{start/sr:.1f}-{end/sr:.1f}s", "detected": detected or ["Fluent"]})
229
 
230
  progress(0.75, desc="πŸ—£οΈ Transcribing with Whisper...")
231
  print("Running Whisper...")