sheilaseidel Claude (@vertexai-global/anthropic.claude-opus-4-5@20251101) committed on
Commit
f95eb98
·
1 Parent(s): 9f65b78

Add pruned 2.1k param model and simplify postprocessing options

Browse files

- Add pruned model (2.1k params) to model dropdown
- Remove EMA postprocessing options, keep only Hysteresis and Median Filtering
- Add model switching handler to reload model when selection changes

Co-Authored-By: Claude (@vertexai-global/anthropic.claude-opus-4-5@20251101) <noreply@anthropic.com>

app.py CHANGED
@@ -105,13 +105,10 @@ current_model_version = "V2"
105
  vad_preprocessor = None
106
 
107
  # Postprocessing options for kiloVAD
108
- # Uses hysteresis thresholding with optional smoothing (median or EMA)
109
  POSTPROCESSING_OPTIONS = [
110
  "Hysteresis Thresholding",
111
  "Median Filtering + Hysteresis Threshold",
112
- "EMA + Hysteresis Threshold (alpha=0.5)",
113
- "EMA + Hysteresis Threshold (alpha=0.4)",
114
- "EMA + Hysteresis Threshold (alpha=0.3)",
115
  ]
116
 
117
  # Model paths configuration with optimized thresholds
@@ -125,6 +122,15 @@ MODEL_PATHS = {
125
  "version": "V2",
126
  "frame_sec": 0.2,
127
  "stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
 
 
 
 
 
 
 
 
 
128
  }
129
  }
130
 
@@ -1472,7 +1478,7 @@ with gr.Blocks(title="kiloVAD") as demo:
1472
  # Model Selection
1473
  gr.Markdown("#### 🤖 Select Model")
1474
  model_dropdown = gr.Dropdown(
1475
- choices=["Unpruned, 81k params"],
1476
  value="Unpruned, 81k params",
1477
  label="Model",
1478
  info="Select a kiloVAD model version"
@@ -1485,9 +1491,6 @@ with gr.Blocks(title="kiloVAD") as demo:
1485
  choices=[
1486
  "Hysteresis Thresholding",
1487
  "Median Filtering + Hysteresis Threshold",
1488
- "EMA + Hysteresis Threshold (alpha=0.5)",
1489
- "EMA + Hysteresis Threshold (alpha=0.4)",
1490
- "EMA + Hysteresis Threshold (alpha=0.3)",
1491
  ],
1492
  value="Hysteresis Thresholding",
1493
  label="Postprocessing",
@@ -1624,6 +1627,33 @@ with gr.Blocks(title="kiloVAD") as demo:
1624
  show_progress=False
1625
  )
1626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1627
  if __name__ == "__main__":
1628
  # Add assets directory to allowed paths for Gradio security
1629
  assets_path = project_root / "assets"
 
105
  vad_preprocessor = None
106
 
107
  # Postprocessing options for kiloVAD
108
+ # Uses hysteresis thresholding, optionally preceded by median-filter smoothing
109
  POSTPROCESSING_OPTIONS = [
110
  "Hysteresis Thresholding",
111
  "Median Filtering + Hysteresis Threshold",
 
 
 
112
  ]
113
 
114
  # Model paths configuration with optimized thresholds
 
122
  "version": "V2",
123
  "frame_sec": 0.2,
124
  "stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
125
+ },
126
+ "Pruned, 2.1k params": {
127
+ "path": project_root / "assets" / "pruned_2k_params",
128
+ "model_file": "pruned_model.pth",
129
+ "threshold_low": 0.7626,
130
+ "threshold_high": 0.9029,
131
+ "version": "V2",
132
+ "frame_sec": 0.2,
133
+ "stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
134
  }
135
  }
136
 
 
1478
  # Model Selection
1479
  gr.Markdown("#### 🤖 Select Model")
1480
  model_dropdown = gr.Dropdown(
1481
+ choices=["Unpruned, 81k params", "Pruned, 2.1k params"],
1482
  value="Unpruned, 81k params",
1483
  label="Model",
1484
  info="Select a kiloVAD model version"
 
1491
  choices=[
1492
  "Hysteresis Thresholding",
1493
  "Median Filtering + Hysteresis Threshold",
 
 
 
1494
  ],
1495
  value="Hysteresis Thresholding",
1496
  label="Postprocessing",
 
1627
  show_progress=False
1628
  )
1629
 
1630
+ # Handle model selection change
1631
+ def on_model_change(model_name, audio, postprocessing):
1632
+ try:
1633
+ success, message = load_vad_model(model_name)
1634
+ if not success:
1635
+ return f"❌ {message}", None, None
1636
+
1637
+ # Get the thresholds for the new model
1638
+ model_info = MODEL_PATHS.get(model_name, {})
1639
+ new_low = model_info.get("threshold_low", 0.7626)
1640
+ new_high = model_info.get("threshold_high", 0.9029)
1641
+
1642
+ if audio is None:
1643
+ return f"✅ {message}", None, None
1644
+
1645
+ # Re-process audio with the new model
1646
+ return process_uploaded_audio(audio, new_low, new_high, postprocessing)
1647
+ except Exception as e:
1648
+ return f"❌ Error switching model: {str(e)}", None, None
1649
+
1650
+ model_dropdown.change(
1651
+ fn=on_model_change,
1652
+ inputs=[model_dropdown, audio_input, postprocessing_dropdown],
1653
+ outputs=[status_display, vad_plot, stereo_audio_output],
1654
+ show_progress=True
1655
+ )
1656
+
1657
  if __name__ == "__main__":
1658
  # Add assets directory to allowed paths for Gradio security
1659
  assets_path = project_root / "assets"
assets/pruned_2k_params/config.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "experiment": {
3
+ "name": "flexiblevad_frame200ms_seed256",
4
+ "debug": false,
5
+ "gpu_id": null,
6
+ "seed": 256
7
+ },
8
+ "data": {
9
+ "data_dir": "/mnt/azureml/cr/j/d94de43aa8c14a2eb32c4b6dcd863283/cap/data-capability/wd/INPUT_vad_data",
10
+ "split_name": "libri_dns_full_no_pure_noise_v2",
11
+ "num_workers": 12,
12
+ "frame_sec": 0.2,
13
+ "use_precomputed": true,
14
+ "precomputed_dir": "/tmp/vad_data/precomputed_features",
15
+ "use_shm_cache": false
16
+ },
17
+ "model": {
18
+ "class": "FlexibleVAD",
19
+ "n_mels": 64,
20
+ "frame_sec": 0.2
21
+ },
22
+ "training": {
23
+ "num_epochs": 40,
24
+ "batch_size": 512,
25
+ "loss": {
26
+ "type": "CrossEntropyLoss",
27
+ "label_smoothing": 0.09
28
+ },
29
+ "pauc_loss": {
30
+ "enabled": false,
31
+ "lambda_pauc": 0.8,
32
+ "tpr_low": 0.9,
33
+ "tpr_high": 1.0,
34
+ "gamma": 10.0,
35
+ "warmup_epochs": 5
36
+ },
37
+ "model_selection": {
38
+ "use_fpr95_for_best_model": false
39
+ },
40
+ "optimizer": {
41
+ "type": "SGD",
42
+ "base_lr": 0.0035,
43
+ "momentum": 0.9,
44
+ "nesterov": true,
45
+ "weight_decay": 0.000875
46
+ },
47
+ "lr_scheduler": {
48
+ "type": "cyclic",
49
+ "min_lr": 1e-05,
50
+ "warmup_epochs": 5,
51
+ "hold_epochs": 5,
52
+ "decay_epochs": 5,
53
+ "cycle_period": 20
54
+ }
55
+ },
56
+ "dataloader": {
57
+ "shuffle_train": true,
58
+ "pin_memory": true,
59
+ "persistent_workers": true,
60
+ "prefetch_factor": 4,
61
+ "drop_last": false
62
+ },
63
+ "distributed": {
64
+ "enabled": true,
65
+ "backend": "nccl",
66
+ "auto_scale_workers": true
67
+ },
68
+ "logging": {
69
+ "level": "INFO",
70
+ "log_to_file": true,
71
+ "log_to_console": true
72
+ },
73
+ "evaluation": {
74
+ "save_best_model": true,
75
+ "eval_noise_rejection": true,
76
+ "plot_roc_curves": true,
77
+ "test_categories": [
78
+ "Clean LibriSpeech",
79
+ "Windy LibriSpeech",
80
+ "LibriSpeech+DNS Synthetic",
81
+ "DNS Speech+Noise",
82
+ "DNS Pure Noise"
83
+ ]
84
+ },
85
+ "actual_num_epochs": 40,
86
+ "model_total_parameters": 81090,
87
+ "device": "cuda:0"
88
+ }
assets/pruned_2k_params/pruned_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ff919ef23f0062e40edfd37b9a8c4f2fd99c233247a401822580adefb6d5247
3
+ size 35175