jaesungbae committed · verified
Commit 8096e64 · 1 Parent(s): 3489d26

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,131 @@
---
license: mit
tags:
- speech
- dysarthria
- severity-estimation
- whisper
- audio-classification
language:
- en
pipeline_tag: audio-classification
---

# Dysarthric Speech Severity Level Classifier

A regression probe trained on top of Whisper-large-v3 encoder features for estimating the severity level of dysarthric speech.

**Score scale:** 1.0 (most severe dysarthria) to 7.0 (typical speech)

## Model Description

This model uses a three-stage training pipeline:
1. **Pseudo-labeling** — A baseline probe generates pseudo-labels for unlabeled data
2. **Contrastive pre-training** — Weakly supervised contrastive learning with typical-speech augmentation
3. **Fine-tuning** — The regression probe is fine-tuned on top of the pre-trained projector

**Architecture:** Whisper-large-v3 encoder (frozen) → LayerNorm → 2-layer MLP (proj_dim=320) → Statistics Pooling (mean+std) → Linear → Score

For details, see our paper:
> **Something from Nothing: Data Augmentation for Robust Severity Level Estimation of Dysarthric Speech** [[arXiv]](https://arxiv.org/abs/2603.15988)

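The probe head is small enough to sketch directly. The following is a minimal, illustrative PyTorch version mirroring `WhisperFeatureProbeV2` in this repo's `pipeline.py`; random features stand in for the frozen Whisper encoder, and `SeverityProbe` is just an illustrative name:

```python
import torch
import torch.nn as nn


class SeverityProbe(nn.Module):
    """Sketch of the probe head: LayerNorm -> 2-layer MLP -> stats pooling -> Linear."""

    def __init__(self, input_dim=1280, proj_dim=320, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.proj1 = nn.Linear(input_dim, proj_dim)
        self.proj2 = nn.Linear(proj_dim, proj_dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        # Statistics pooling concatenates mean and std -> 2 * proj_dim inputs
        self.head = nn.Linear(proj_dim * 2, 1)

    def forward(self, feats):  # feats: (batch, frames, 1280) encoder features
        x = self.drop(self.act(self.proj1(self.norm(feats))))
        x = self.drop(self.act(self.proj2(x)))
        pooled = torch.cat([x.mean(dim=1), x.std(dim=1)], dim=1)
        return self.head(pooled)  # (batch, 1) severity score


probe = SeverityProbe().eval()
with torch.no_grad():
    score = probe(torch.randn(1, 100, 1280))  # 100 random "frames"
print(score.shape)  # torch.Size([1, 1])
```

Because the mean/std pooling collapses the time axis, the probe produces one score per utterance regardless of audio length.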
## Available Checkpoints

This repository contains **9 checkpoints** trained with different contrastive losses:

| Checkpoint | Contrastive Loss | τ |
|---|---|---|
| `proposed_L_coarse_tau0.1` | Proposed (L_coarse) | 0.1 |
| `proposed_L_coarse_tau1.0` | Proposed (L_coarse) | 1.0 |
| **`proposed_L_coarse_tau10.0`** (default) | Proposed (L_coarse) | 10.0 |
| `proposed_L_coarse_tau50.0` | Proposed (L_coarse) | 50.0 |
| `proposed_L_coarse_tau100.0` | Proposed (L_coarse) | 100.0 |
| `proposed_L_cont_tau0.1` | Proposed (L_cont) | 0.1 |
| `proposed_L_dis_tau1.0` | Proposed (L_dis) | 1.0 |
| `rank-n-contrast_tau100.0` | Rank-N-Contrast | 100.0 |
| `simclr_tau0.1` | SimCLR | 0.1 |

## Usage

### With the custom pipeline

```python
import sys

from huggingface_hub import snapshot_download

# Download the model
model_dir = snapshot_download("jaesungbae/severity-level-classifier")

# Make pipeline.py importable from the downloaded snapshot
sys.path.insert(0, model_dir)

# Load pipeline (defaults to proposed_L_coarse_tau10.0)
from pipeline import PreTrainedPipeline
pipe = PreTrainedPipeline(model_dir)

# Run inference
result = pipe("/path/to/audio.wav")
print(result)
# {"severity_score": 4.25, "raw_score": 4.2483, "model_name": "proposed_L_coarse_tau10.0"}
```
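The `severity_score` field above is simply the raw probe output clipped to the 1.0–7.0 scale and rounded. The post-processing implemented in this repo's `pipeline.py` reduces to:

```python
def postprocess(raw_score: float) -> dict:
    # Clip the raw regression output to the 1.0-7.0 severity scale,
    # as done at the end of PreTrainedPipeline.__call__ in pipeline.py
    return {
        "severity_score": round(max(1.0, min(7.0, raw_score)), 2),
        "raw_score": round(raw_score, 4),
    }

print(postprocess(4.2483))  # {'severity_score': 4.25, 'raw_score': 4.2483}
```

Raw scores outside the scale (e.g., a raw output of 7.3) are therefore reported as 7.0 in `severity_score`, while `raw_score` preserves the unclipped value.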

### Select a specific checkpoint

```python
# Option 1: specify at initialization
pipe = PreTrainedPipeline(model_dir, model_name="simclr_tau0.1")

# Option 2: switch at runtime (Whisper & VAD stay loaded)
pipe.switch_model("rank-n-contrast_tau100.0")
result = pipe("/path/to/audio.wav")

# Option 3: override per call
result = pipe("/path/to/audio.wav", model_name="proposed_L_dis_tau1.0")
```

### List available checkpoints

```python
print(pipe.list_models())
# ['proposed_L_coarse_tau0.1', 'proposed_L_coarse_tau1.0', ...]
```

### Compare all checkpoints on a single file

```python
for name in pipe.list_models():
    result = pipe("/path/to/audio.wav", model_name=name)
    print(f"{name}: {result['severity_score']}")
```

### Standalone inference

Clone the [full repository](https://github.com/JaesungBae/DA-DSQA) and run:

```bash
python inference.py \
    --wav /path/to/audio.wav \
    --checkpoint ./checkpoints/stage3/proposed_L_coarse_tau10.0/average
```

## Requirements

- Python 3.10+
- PyTorch + torchaudio
- transformers >= 4.40.0
- safetensors >= 0.4.0
- Silero VAD (loaded via `torch.hub` at runtime)

## Runtime Dependencies

This model loads **openai/whisper-large-v3** (~6GB) and **Silero VAD** at initialization time. Ensure sufficient memory is available.

## Citation

```bibtex
@misc{bae2026something,
    title = {Something from Nothing: Data Augmentation for Robust Severity Level Estimation of Dysarthric Speech},
    author = {Jaesung Bae and Xiuwen Zheng and Minje Kim and Chang D. Yoo and Mark Hasegawa-Johnson},
    year = {2026},
    eprint = {2603.15988},
    archivePrefix = {arXiv},
    primaryClass = {eess.AS},
    url = {https://arxiv.org/abs/2603.15988}
}
```
checkpoints/proposed_L_coarse_tau0.1/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5db1659b456f24c1f59d718e7eb53953c0d3490c634db37ae9dc0486597a844
size 2064020
checkpoints/proposed_L_coarse_tau1.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a0a640d0d054fbc551004472c838a9d573dfbe9a47b218470ee54f551476559a
size 2064020
checkpoints/proposed_L_coarse_tau10.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cf031ed59c2b0c127e8d0bcee38a57936aea5e32b4fc5750dc0985ffe02a2c94
size 2064020
checkpoints/proposed_L_coarse_tau100.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67574bec8e6cbaab9ea9c74f5c331a8590d08cbeb02f55c28e1eabaa7cbf6788
size 2064020
checkpoints/proposed_L_coarse_tau50.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdc5e9a6751f03db05dc74e3029308e9f1c39f379ea8496f40cb9ca65c953097
size 2064020
checkpoints/proposed_L_cont_tau0.1/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c28f3c4cb5e5e5e424ae4672d362fa1a3e2642834b851c7d7e9c152036b2b9f3
size 2064020
checkpoints/proposed_L_dis_tau1.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38ea3fd55b00ae1e7ea3ddfa64e623a73bf482ba925f6c09794d7d636e9c79e9
size 2064020
checkpoints/rank-n-contrast_tau100.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3457d677fe1467f7a4376d232e1084fedcc847f29afb5511b666c886b743543f
size 2064020
checkpoints/simclr_tau0.1/model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f53f4e3121d9ede7c2c7c37f38a7eae1e63f3b664fca2af1a6383ff540d110e
size 2064020
config.json ADDED
@@ -0,0 +1,22 @@
{
  "model_type": "whisper_severity_probe",
  "architectures": ["WhisperFeatureProbeV2"],
  "input_dim": 1280,
  "proj_dim": 320,
  "dropout": 0.1,
  "num_classes": 1,
  "whisper_model_name": "openai/whisper-large-v3",
  "sampling_rate": 16000,
  "default_checkpoint": "proposed_L_coarse_tau10.0",
  "available_checkpoints": [
    "proposed_L_coarse_tau0.1",
    "proposed_L_coarse_tau1.0",
    "proposed_L_coarse_tau10.0",
    "proposed_L_coarse_tau50.0",
    "proposed_L_coarse_tau100.0",
    "proposed_L_cont_tau0.1",
    "proposed_L_dis_tau1.0",
    "rank-n-contrast_tau100.0",
    "simclr_tau0.1"
  ]
}
pipeline.py ADDED
@@ -0,0 +1,307 @@
"""
Custom inference pipeline for HuggingFace Hub.

Pipeline: WAV -> Silero VAD -> Whisper feature extraction -> Probe -> Severity score

Score scale: 1.0 (most severe) to 7.0 (typical speech)

Supports multiple checkpoints. Pass `model_name` to select which checkpoint to use:

    pipe = PreTrainedPipeline(model_dir)                              # default
    pipe = PreTrainedPipeline(model_dir, model_name="simclr_tau0.1")  # specific

Available checkpoints:
    - proposed_L_coarse_tau0.1
    - proposed_L_coarse_tau1.0
    - proposed_L_coarse_tau10.0 (default)
    - proposed_L_coarse_tau50.0
    - proposed_L_coarse_tau100.0
    - proposed_L_cont_tau0.1
    - proposed_L_dis_tau1.0
    - rank-n-contrast_tau100.0
    - simclr_tau0.1
"""

import io
import json
import os

import torch
import torch.nn as nn
import torchaudio

SAMPLING_RATE = 16000
WHISPER_MODEL_NAME = "openai/whisper-large-v3"
WHISPER_HIDDEN_DIM = 1280
DEFAULT_CHECKPOINT = "proposed_L_coarse_tau10.0"


class WhisperFeatureProbeV2(nn.Module):
    """
    Regression probe on Whisper encoder features.

    Architecture: LayerNorm -> Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout
                  -> Statistics Pooling (mean+std) -> Linear(proj_dim*2, num_classes)
    """

    def __init__(self, input_dim=1280, proj_dim=256, dropout=0.1, num_classes=1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.projector = nn.Linear(input_dim, proj_dim)
        self.projector2 = nn.Linear(proj_dim, proj_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(proj_dim * 2, num_classes)

    def forward(self, input_values, lengths=None, **kwargs):
        x = self.norm(input_values)
        x = self.dropout(self.relu(self.projector(x)))
        x = self.dropout(self.relu(self.projector2(x)))

        if lengths is not None:
            batch_size, max_len, _ = x.shape
            mask = (
                torch.arange(max_len, device=x.device).unsqueeze(0)
                < lengths.unsqueeze(1)
            )
            mask_f = mask.unsqueeze(-1).float()
            x_masked = x * mask_f
            lengths_f = lengths.unsqueeze(1).float().clamp(min=1)
            mean = x_masked.sum(dim=1) / lengths_f
            var = (x_masked**2).sum(dim=1) / lengths_f - mean**2
            std = var.clamp(min=1e-8).sqrt()
        else:
            mean = x.mean(dim=1)
            std = x.std(dim=1)

        pooled = torch.cat([mean, std], dim=1)
        logits = self.classifier(pooled)

        return type("Output", (), {"logits": logits, "hidden_states": pooled})()


def _load_vad():
    """Load Silero VAD model."""
    model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        onnx=False,
    )
    model.eval()
    get_speech_timestamps = utils[0]
    return model, get_speech_timestamps


def _apply_vad(wav, vad_model, get_speech_timestamps):
    """Apply VAD and return concatenated speech segments."""
    if wav.dim() > 1:
        wav = wav.squeeze()

    speech_timestamps = get_speech_timestamps(
        wav,
        vad_model,
        threshold=0.5,
        sampling_rate=SAMPLING_RATE,
        min_speech_duration_ms=250,
        min_silence_duration_ms=100,
        speech_pad_ms=30,
    )

    if not speech_timestamps:
        return wav

    segments = [
        wav[max(0, ts["start"]) : min(len(wav), ts["end"])]
        for ts in speech_timestamps
    ]
    return torch.cat(segments)


def _extract_features(wav, whisper_model, processor, device):
    """Extract Whisper encoder last-layer hidden states."""
    if isinstance(wav, torch.Tensor):
        wav_np = wav.cpu().numpy()
    else:
        wav_np = wav

    feat_len = len(wav_np) // 320

    input_features = processor(
        wav_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(
        device=device, dtype=next(whisper_model.parameters()).dtype
    )

    with torch.no_grad():
        out = whisper_model.encoder(input_features, output_hidden_states=True)

    return out.last_hidden_state[:, :feat_len, :].float()


def _load_probe(checkpoint_dir, device):
    """Load a probe model from a checkpoint directory."""
    probe = WhisperFeatureProbeV2(
        input_dim=WHISPER_HIDDEN_DIM, proj_dim=320, num_classes=1
    )
    safe_path = os.path.join(checkpoint_dir, "model.safetensors")
    bin_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    if os.path.isfile(safe_path):
        from safetensors.torch import load_file

        state_dict = load_file(safe_path, device=str(device))
    elif os.path.isfile(bin_path):
        state_dict = torch.load(
            bin_path, map_location=device, weights_only=True
        )
    else:
        raise FileNotFoundError(
            f"No model.safetensors or pytorch_model.bin in {checkpoint_dir}"
        )
    probe.load_state_dict(state_dict)
    probe.to(device).eval()
    return probe


def _discover_checkpoints(path):
    """Find all available checkpoint subdirectories."""
    checkpoints_dir = os.path.join(path, "checkpoints")
    if not os.path.isdir(checkpoints_dir):
        return []
    names = []
    for name in sorted(os.listdir(checkpoints_dir)):
        ckpt_dir = os.path.join(checkpoints_dir, name)
        if os.path.isdir(ckpt_dir) and (
            os.path.isfile(os.path.join(ckpt_dir, "model.safetensors"))
            or os.path.isfile(os.path.join(ckpt_dir, "pytorch_model.bin"))
        ):
            names.append(name)
    return names


class PreTrainedPipeline:
    """
    HuggingFace custom inference pipeline for dysarthric speech severity estimation.

    Accepts a WAV file path or raw audio bytes and returns a severity score
    on a 1.0 (most severe) to 7.0 (typical speech) scale.

    Supports multiple checkpoints stored under `checkpoints/` in the model repo.
    Use `model_name` to select which checkpoint, or call `switch_model()` to
    change at runtime.

    Args:
        path: Path to the downloaded HuggingFace model directory.
        model_name: Name of the checkpoint to load (e.g., "proposed_L_coarse_tau10.0").
            If None, uses the default from config.json.
    """

    def __init__(self, path: str, model_name: str = None):
        self.path = path
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Read config
        config_path = os.path.join(path, "config.json")
        if os.path.isfile(config_path):
            with open(config_path) as f:
                self.config = json.load(f)
        else:
            self.config = {}

        # Discover available checkpoints
        self.available_checkpoints = _discover_checkpoints(path)
        if not self.available_checkpoints:
            raise FileNotFoundError(
                f"No checkpoints found under {os.path.join(path, 'checkpoints')}/"
            )

        # Load probe for the selected checkpoint
        if model_name is None:
            model_name = self.config.get("default_checkpoint", DEFAULT_CHECKPOINT)
        self.current_model_name = None
        self.probe = None
        self.switch_model(model_name)

        # Load Whisper encoder (shared across all checkpoints)
        from transformers import WhisperFeatureExtractor, WhisperModel

        self.processor = WhisperFeatureExtractor.from_pretrained(
            WHISPER_MODEL_NAME
        )
        self.whisper = WhisperModel.from_pretrained(WHISPER_MODEL_NAME)
        self.whisper.eval().to(self.device)

        # Load Silero VAD (shared across all checkpoints)
        self.vad_model, self.get_speech_timestamps = _load_vad()

    def switch_model(self, model_name: str):
        """
        Switch to a different checkpoint without reloading Whisper or VAD.

        Args:
            model_name: Name of the checkpoint (e.g., "simclr_tau0.1")
        """
        if model_name == self.current_model_name:
            return

        if model_name not in self.available_checkpoints:
            raise ValueError(
                f"Checkpoint '{model_name}' not found. "
                f"Available: {self.available_checkpoints}"
            )

        checkpoint_dir = os.path.join(self.path, "checkpoints", model_name)
        self.probe = _load_probe(checkpoint_dir, self.device)
        self.current_model_name = model_name

    def list_models(self):
        """Return list of available checkpoint names."""
        return list(self.available_checkpoints)

    def __call__(self, inputs, model_name: str = None):
        """
        Run severity estimation on audio input.

        Args:
            inputs: file path (str) or raw audio bytes
            model_name: optionally override the checkpoint for this call

        Returns:
            dict with "severity_score" (clipped to 1-7), "raw_score",
            and "model_name"
        """
        if model_name is not None:
            self.switch_model(model_name)

        # Load audio: a file path, or raw bytes wrapped in a buffer
        if isinstance(inputs, str):
            wav, sr = torchaudio.load(inputs)
        else:
            wav, sr = torchaudio.load(io.BytesIO(inputs))

        if sr != SAMPLING_RATE:
            wav = torchaudio.functional.resample(wav, sr, SAMPLING_RATE)
        wav = wav.squeeze()

        # VAD
        wav = _apply_vad(wav, self.vad_model, self.get_speech_timestamps)

        # Whisper feature extraction
        features = _extract_features(
            wav, self.whisper, self.processor, self.device
        )

        # Probe inference
        with torch.no_grad():
            output = self.probe(features)
            score = output.logits.item()

        return {
            "severity_score": round(max(1.0, min(7.0, score)), 2),
            "raw_score": round(score, 4),
            "model_name": self.current_model_name,
        }