Spaces:
Sleeping
Sleeping
Commit ·
f95eb98
1
Parent(s): 9f65b78
Add pruned 2.1k param model and simplify postprocessing options
Browse files
- Add pruned model (2.1k params) to model dropdown
- Remove EMA postprocessing options, keep only Hysteresis and Median Filtering
- Add model switching handler to reload model when selection changes
Co-Authored-By: Claude (@vertexai-global/anthropic.claude-opus-4-5@20251101) <noreply@anthropic.com>
- app.py +38 -8
- assets/pruned_2k_params/config.json +88 -0
- assets/pruned_2k_params/pruned_model.pth +3 -0
app.py
CHANGED
|
@@ -105,13 +105,10 @@ current_model_version = "V2"
|
|
| 105 |
vad_preprocessor = None
|
| 106 |
|
| 107 |
# Postprocessing options for kiloVAD
|
| 108 |
-
# Uses hysteresis thresholding with optional smoothing (median
|
| 109 |
POSTPROCESSING_OPTIONS = [
|
| 110 |
"Hysteresis Thresholding",
|
| 111 |
"Median Filtering + Hysteresis Threshold",
|
| 112 |
-
"EMA + Hysteresis Threshold (alpha=0.5)",
|
| 113 |
-
"EMA + Hysteresis Threshold (alpha=0.4)",
|
| 114 |
-
"EMA + Hysteresis Threshold (alpha=0.3)",
|
| 115 |
]
|
| 116 |
|
| 117 |
# Model paths configuration with optimized thresholds
|
|
@@ -125,6 +122,15 @@ MODEL_PATHS = {
|
|
| 125 |
"version": "V2",
|
| 126 |
"frame_sec": 0.2,
|
| 127 |
"stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
}
|
| 129 |
}
|
| 130 |
|
|
@@ -1472,7 +1478,7 @@ with gr.Blocks(title="kiloVAD") as demo:
|
|
| 1472 |
# Model Selection
|
| 1473 |
gr.Markdown("#### 🤖 Select Model")
|
| 1474 |
model_dropdown = gr.Dropdown(
|
| 1475 |
-
choices=["Unpruned, 81k params"],
|
| 1476 |
value="Unpruned, 81k params",
|
| 1477 |
label="Model",
|
| 1478 |
info="Select a kiloVAD model version"
|
|
@@ -1485,9 +1491,6 @@ with gr.Blocks(title="kiloVAD") as demo:
|
|
| 1485 |
choices=[
|
| 1486 |
"Hysteresis Thresholding",
|
| 1487 |
"Median Filtering + Hysteresis Threshold",
|
| 1488 |
-
"EMA + Hysteresis Threshold (alpha=0.5)",
|
| 1489 |
-
"EMA + Hysteresis Threshold (alpha=0.4)",
|
| 1490 |
-
"EMA + Hysteresis Threshold (alpha=0.3)",
|
| 1491 |
],
|
| 1492 |
value="Hysteresis Thresholding",
|
| 1493 |
label="Postprocessing",
|
|
@@ -1624,6 +1627,33 @@ with gr.Blocks(title="kiloVAD") as demo:
|
|
| 1624 |
show_progress=False
|
| 1625 |
)
|
| 1626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1627 |
if __name__ == "__main__":
|
| 1628 |
# Add assets directory to allowed paths for Gradio security
|
| 1629 |
assets_path = project_root / "assets"
|
|
|
|
| 105 |
vad_preprocessor = None
|
| 106 |
|
| 107 |
# Postprocessing options for kiloVAD
|
| 108 |
+
# Uses hysteresis thresholding with optional smoothing (median)
|
| 109 |
POSTPROCESSING_OPTIONS = [
|
| 110 |
"Hysteresis Thresholding",
|
| 111 |
"Median Filtering + Hysteresis Threshold",
|
|
|
|
|
|
|
|
|
|
| 112 |
]
|
| 113 |
|
| 114 |
# Model paths configuration with optimized thresholds
|
|
|
|
| 122 |
"version": "V2",
|
| 123 |
"frame_sec": 0.2,
|
| 124 |
"stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
|
| 125 |
+
},
|
| 126 |
+
"Pruned, 2.1k params": {
|
| 127 |
+
"path": project_root / "assets" / "pruned_2k_params",
|
| 128 |
+
"model_file": "pruned_model.pth",
|
| 129 |
+
"threshold_low": 0.7626,
|
| 130 |
+
"threshold_high": 0.9029,
|
| 131 |
+
"version": "V2",
|
| 132 |
+
"frame_sec": 0.2,
|
| 133 |
+
"stride_sec": 0.05 # 50ms stride with overlapping 200ms frames
|
| 134 |
}
|
| 135 |
}
|
| 136 |
|
|
|
|
| 1478 |
# Model Selection
|
| 1479 |
gr.Markdown("#### 🤖 Select Model")
|
| 1480 |
model_dropdown = gr.Dropdown(
|
| 1481 |
+
choices=["Unpruned, 81k params", "Pruned, 2.1k params"],
|
| 1482 |
value="Unpruned, 81k params",
|
| 1483 |
label="Model",
|
| 1484 |
info="Select a kiloVAD model version"
|
|
|
|
| 1491 |
choices=[
|
| 1492 |
"Hysteresis Thresholding",
|
| 1493 |
"Median Filtering + Hysteresis Threshold",
|
|
|
|
|
|
|
|
|
|
| 1494 |
],
|
| 1495 |
value="Hysteresis Thresholding",
|
| 1496 |
label="Postprocessing",
|
|
|
|
| 1627 |
show_progress=False
|
| 1628 |
)
|
| 1629 |
|
| 1630 |
+
# Handle model selection change
|
| 1631 |
+
def on_model_change(model_name, audio, postprocessing):
|
| 1632 |
+
try:
|
| 1633 |
+
success, message = load_vad_model(model_name)
|
| 1634 |
+
if not success:
|
| 1635 |
+
return f"❌ {message}", None, None
|
| 1636 |
+
|
| 1637 |
+
# Get the thresholds for the new model
|
| 1638 |
+
model_info = MODEL_PATHS.get(model_name, {})
|
| 1639 |
+
new_low = model_info.get("threshold_low", 0.7626)
|
| 1640 |
+
new_high = model_info.get("threshold_high", 0.9029)
|
| 1641 |
+
|
| 1642 |
+
if audio is None:
|
| 1643 |
+
return f"✅ {message}", None, None
|
| 1644 |
+
|
| 1645 |
+
# Re-process audio with the new model
|
| 1646 |
+
return process_uploaded_audio(audio, new_low, new_high, postprocessing)
|
| 1647 |
+
except Exception as e:
|
| 1648 |
+
return f"❌ Error switching model: {str(e)}", None, None
|
| 1649 |
+
|
| 1650 |
+
model_dropdown.change(
|
| 1651 |
+
fn=on_model_change,
|
| 1652 |
+
inputs=[model_dropdown, audio_input, postprocessing_dropdown],
|
| 1653 |
+
outputs=[status_display, vad_plot, stereo_audio_output],
|
| 1654 |
+
show_progress=True
|
| 1655 |
+
)
|
| 1656 |
+
|
| 1657 |
if __name__ == "__main__":
|
| 1658 |
# Add assets directory to allowed paths for Gradio security
|
| 1659 |
assets_path = project_root / "assets"
|
assets/pruned_2k_params/config.json
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"experiment": {
|
| 3 |
+
"name": "flexiblevad_frame200ms_seed256",
|
| 4 |
+
"debug": false,
|
| 5 |
+
"gpu_id": null,
|
| 6 |
+
"seed": 256
|
| 7 |
+
},
|
| 8 |
+
"data": {
|
| 9 |
+
"data_dir": "/mnt/azureml/cr/j/d94de43aa8c14a2eb32c4b6dcd863283/cap/data-capability/wd/INPUT_vad_data",
|
| 10 |
+
"split_name": "libri_dns_full_no_pure_noise_v2",
|
| 11 |
+
"num_workers": 12,
|
| 12 |
+
"frame_sec": 0.2,
|
| 13 |
+
"use_precomputed": true,
|
| 14 |
+
"precomputed_dir": "/tmp/vad_data/precomputed_features",
|
| 15 |
+
"use_shm_cache": false
|
| 16 |
+
},
|
| 17 |
+
"model": {
|
| 18 |
+
"class": "FlexibleVAD",
|
| 19 |
+
"n_mels": 64,
|
| 20 |
+
"frame_sec": 0.2
|
| 21 |
+
},
|
| 22 |
+
"training": {
|
| 23 |
+
"num_epochs": 40,
|
| 24 |
+
"batch_size": 512,
|
| 25 |
+
"loss": {
|
| 26 |
+
"type": "CrossEntropyLoss",
|
| 27 |
+
"label_smoothing": 0.09
|
| 28 |
+
},
|
| 29 |
+
"pauc_loss": {
|
| 30 |
+
"enabled": false,
|
| 31 |
+
"lambda_pauc": 0.8,
|
| 32 |
+
"tpr_low": 0.9,
|
| 33 |
+
"tpr_high": 1.0,
|
| 34 |
+
"gamma": 10.0,
|
| 35 |
+
"warmup_epochs": 5
|
| 36 |
+
},
|
| 37 |
+
"model_selection": {
|
| 38 |
+
"use_fpr95_for_best_model": false
|
| 39 |
+
},
|
| 40 |
+
"optimizer": {
|
| 41 |
+
"type": "SGD",
|
| 42 |
+
"base_lr": 0.0035,
|
| 43 |
+
"momentum": 0.9,
|
| 44 |
+
"nesterov": true,
|
| 45 |
+
"weight_decay": 0.000875
|
| 46 |
+
},
|
| 47 |
+
"lr_scheduler": {
|
| 48 |
+
"type": "cyclic",
|
| 49 |
+
"min_lr": 1e-05,
|
| 50 |
+
"warmup_epochs": 5,
|
| 51 |
+
"hold_epochs": 5,
|
| 52 |
+
"decay_epochs": 5,
|
| 53 |
+
"cycle_period": 20
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"dataloader": {
|
| 57 |
+
"shuffle_train": true,
|
| 58 |
+
"pin_memory": true,
|
| 59 |
+
"persistent_workers": true,
|
| 60 |
+
"prefetch_factor": 4,
|
| 61 |
+
"drop_last": false
|
| 62 |
+
},
|
| 63 |
+
"distributed": {
|
| 64 |
+
"enabled": true,
|
| 65 |
+
"backend": "nccl",
|
| 66 |
+
"auto_scale_workers": true
|
| 67 |
+
},
|
| 68 |
+
"logging": {
|
| 69 |
+
"level": "INFO",
|
| 70 |
+
"log_to_file": true,
|
| 71 |
+
"log_to_console": true
|
| 72 |
+
},
|
| 73 |
+
"evaluation": {
|
| 74 |
+
"save_best_model": true,
|
| 75 |
+
"eval_noise_rejection": true,
|
| 76 |
+
"plot_roc_curves": true,
|
| 77 |
+
"test_categories": [
|
| 78 |
+
"Clean LibriSpeech",
|
| 79 |
+
"Windy LibriSpeech",
|
| 80 |
+
"LibriSpeech+DNS Synthetic",
|
| 81 |
+
"DNS Speech+Noise",
|
| 82 |
+
"DNS Pure Noise"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
"actual_num_epochs": 40,
|
| 86 |
+
"model_total_parameters": 81090,
|
| 87 |
+
"device": "cuda:0"
|
| 88 |
+
}
|
assets/pruned_2k_params/pruned_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ff919ef23f0062e40edfd37b9a8c4f2fd99c233247a401822580adefb6d5247
|
| 3 |
+
size 35175
|