jaesungbae committed
Commit 911f61c · verified · 1 Parent(s): 8096e64

Upload folder using huggingface_hub

Files changed (6):
  1. README.md +49 -18
  2. __pycache__/pipeline.cpython-310.pyc +0 -0
  3. config.json +1 -1
  4. pipeline.py +73 -12
  5. requirements.txt +8 -0
  6. test.py +12 -0
README.md CHANGED
@@ -17,6 +17,8 @@ A regression probe trained on top of Whisper-large-v3 encoder features for estim
 
 **Score scale:** 1.0 (most severe dysarthria) to 7.0 (typical speech)
 
+**GitHub:** [JaesungBae/DA-DSQA](https://github.com/JaesungBae/DA-DSQA)
+
 ## Model Description
 
 This model uses a three-stage training pipeline:
@@ -37,14 +39,43 @@ This repository contains **9 checkpoints** trained with different contrastive lo
 |---|---|---|
 | `proposed_L_coarse_tau0.1` | Proposed (L_coarse) | 0.1 |
 | `proposed_L_coarse_tau1.0` | Proposed (L_coarse) | 1.0 |
-| **`proposed_L_coarse_tau10.0`** (default) | Proposed (L_coarse) | 10.0 |
+| `proposed_L_coarse_tau10.0` | Proposed (L_coarse) | 10.0 |
 | `proposed_L_coarse_tau50.0` | Proposed (L_coarse) | 50.0 |
-| `proposed_L_coarse_tau100.0` | Proposed (L_coarse) | 100.0 |
+| **`proposed_L_coarse_tau100.0`** (default) | Proposed (L_coarse) | 100.0 |
 | `proposed_L_cont_tau0.1` | Proposed (L_cont) | 0.1 |
 | `proposed_L_dis_tau1.0` | Proposed (L_dis) | 1.0 |
 | `rank-n-contrast_tau100.0` | Rank-N-Contrast | 100.0 |
 | `simclr_tau0.1` | SimCLR | 0.1 |
 
+## Setup
+
+### 1. Create conda environment
+
+```bash
+conda create -n da-dsqa python=3.10 -y
+conda activate da-dsqa
+```
+
+### 2. Install PyTorch with CUDA
+
+```bash
+conda install pytorch torchaudio -c pytorch -y
+```
+
+> For a GPU build with a specific CUDA version, see [pytorch.org](https://pytorch.org/get-started/locally/) for the appropriate command.
+
+### 3. Install remaining dependencies
+
+```bash
+pip install -r requirements.txt
+```
+
+> **Note:** [Silero VAD](https://github.com/snakers4/silero-vad) is loaded automatically at runtime via `torch.hub` — no separate installation needed.
+
+### Runtime Dependencies
+
+This model loads **openai/whisper-large-v3** (~6GB) and **Silero VAD** at initialization time. Ensure sufficient memory is available.
+
 ## Usage
 
 ### With the custom pipeline
@@ -53,16 +84,16 @@ This repository contains **9 checkpoints** trained with different contrastive lo
 from huggingface_hub import snapshot_download
 
 # Download the model
-model_dir = snapshot_download("jaesungbae/severity-level-classifier")
+model_dir = snapshot_download("jaesungbae/da-dsqa")
 
-# Load pipeline (defaults to proposed_L_coarse_tau10.0)
+# Load pipeline (defaults to proposed_L_coarse_tau100.0)
 from pipeline import PreTrainedPipeline
 pipe = PreTrainedPipeline(model_dir)
 
 # Run inference
 result = pipe("/path/to/audio.wav")
 print(result)
-# {"severity_score": 4.25, "raw_score": 4.2483, "model_name": "proposed_L_coarse_tau10.0"}
+# {"severity_score": 4.25, "raw_score": 4.2483, "model_name": "proposed_L_coarse_tau100.0"}
 ```
 
 ### Select a specific checkpoint
@@ -79,6 +110,18 @@ result = pipe("/path/to/audio.wav")
 result = pipe("/path/to/audio.wav", model_name="proposed_L_dis_tau1.0")
 ```
 
+### Batch inference
+
+```python
+results = pipe.batch_inference([
+    "/path/to/audio1.wav",
+    "/path/to/audio2.wav",
+    "/path/to/audio3.wav",
+])
+for r in results:
+    print(f"{r['file']}: {r['severity_score']}")
+```
+
 ### List available checkpoints
 
 ```python
@@ -101,21 +144,9 @@ Clone the [full repository](https://github.com/JaesungBae/DA-DSQA) and run:
 ```bash
 python inference.py \
     --wav /path/to/audio.wav \
-    --checkpoint ./checkpoints/stage3/proposed_L_coarse_tau10.0/average
+    --checkpoint ./checkpoints/stage3/proposed_L_coarse_tau100.0/average
 ```
 
-## Requirements
-
-- Python 3.10+
-- PyTorch + torchaudio
-- transformers >= 4.40.0
-- safetensors >= 0.4.0
-- Silero VAD (loaded via `torch.hub` at runtime)
-
-## Runtime Dependencies
-
-This model loads **openai/whisper-large-v3** (~6GB) and **Silero VAD** at initialization time. Ensure sufficient memory is available.
-
 ## Citation
 
 ```bibtex
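The README's usage example reports both a `raw_score` (4.2483) and a clamped, rounded `severity_score` (4.25). A minimal sketch of that mapping, assuming only the 1.0–7.0 scale and two-decimal rounding stated in the README (the helper name is hypothetical, not part of the pipeline's API):

```python
def clamp_severity(raw_score: float) -> float:
    """Map a raw probe output onto the 1.0 (most severe) to 7.0 (typical)
    scale, rounded to two decimals, matching the README's example output."""
    return round(max(1.0, min(7.0, raw_score)), 2)

print(clamp_severity(4.2483))  # 4.25, as in the README example
print(clamp_severity(9.1))     # 7.0: out-of-range scores saturate
```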
__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (9.13 kB)
config.json CHANGED
@@ -7,7 +7,7 @@
   "num_classes": 1,
   "whisper_model_name": "openai/whisper-large-v3",
   "sampling_rate": 16000,
-  "default_checkpoint": "proposed_L_coarse_tau10.0",
+  "default_checkpoint": "proposed_L_coarse_tau100.0",
   "available_checkpoints": [
     "proposed_L_coarse_tau0.1",
     "proposed_L_coarse_tau1.0",
pipeline.py CHANGED
@@ -28,12 +28,13 @@ import os
 
 import torch
 import torch.nn as nn
+import soundfile as sf
 import torchaudio
 
 SAMPLING_RATE = 16000
 WHISPER_MODEL_NAME = "openai/whisper-large-v3"
 WHISPER_HIDDEN_DIM = 1280
-DEFAULT_CHECKPOINT = "proposed_L_coarse_tau10.0"
+DEFAULT_CHECKPOINT = "proposed_L_coarse_tau100.0"
 
 
 class WhisperFeatureProbeV2(nn.Module):
@@ -260,6 +261,19 @@ class PreTrainedPipeline:
         """Return list of available checkpoint names."""
         return list(self.available_checkpoints)
 
+    def _load_wav(self, inputs):
+        """Load and preprocess a single audio input to a 1D waveform tensor."""
+        if isinstance(inputs, (bytes, bytearray)):
+            data, sr = sf.read(io.BytesIO(inputs), dtype="float32")
+        else:
+            data, sr = sf.read(inputs, dtype="float32")
+        wav = torch.from_numpy(data).float()
+        if wav.dim() > 1:
+            wav = wav.mean(dim=-1)
+        if sr != SAMPLING_RATE:
+            wav = torchaudio.functional.resample(wav, sr, SAMPLING_RATE)
+        return wav
+
     def __call__(self, inputs, model_name: str = None):
         """
         Run severity estimation on audio input.
@@ -275,17 +289,7 @@ class PreTrainedPipeline:
         if model_name is not None:
            self.switch_model(model_name)
 
-        # Load audio
-        if isinstance(inputs, str):
-            wav, sr = torchaudio.load(inputs)
-        elif isinstance(inputs, bytes):
-            wav, sr = torchaudio.load(io.BytesIO(inputs))
-        else:
-            wav, sr = torchaudio.load(io.BytesIO(inputs))
-
-        if sr != SAMPLING_RATE:
-            wav = torchaudio.functional.resample(wav, sr, SAMPLING_RATE)
-        wav = wav.squeeze()
+        wav = self._load_wav(inputs)
 
         # VAD
         wav = _apply_vad(wav, self.vad_model, self.get_speech_timestamps)
@@ -305,3 +309,60 @@
             "raw_score": round(score, 4),
             "model_name": self.current_model_name,
         }
+
+    def batch_inference(self, input_list, model_name: str = None):
+        """
+        Run severity estimation on a batch of audio files.
+
+        Whisper processes one file at a time (due to variable-length VAD output),
+        but the probe runs as a single padded batch for efficiency.
+
+        Args:
+            input_list: list of file paths (str) or raw audio bytes
+            model_name: optionally override the checkpoint for this call
+
+        Returns:
+            list of dicts, each with "file", "severity_score", "raw_score",
+            and "model_name"
+        """
+        if model_name is not None:
+            self.switch_model(model_name)
+
+        # Extract features for each file
+        all_features = []
+        lengths = []
+        for inputs in input_list:
+            wav = self._load_wav(inputs)
+            wav = _apply_vad(wav, self.vad_model, self.get_speech_timestamps)
+            features = _extract_features(
+                wav, self.whisper, self.processor, self.device
+            )
+            feat = features.squeeze(0)  # (T, hidden_dim)
+            all_features.append(feat)
+            lengths.append(feat.shape[0])
+
+        # Pad and batch
+        max_len = max(lengths)
+        hidden_dim = all_features[0].shape[1]
+        batch_size = len(all_features)
+
+        padded = torch.zeros(batch_size, max_len, hidden_dim, device=self.device)
+        for i, feat in enumerate(all_features):
+            padded[i, : lengths[i]] = feat
+        lengths_tensor = torch.tensor(lengths, device=self.device)
+
+        # Batched probe inference
+        with torch.no_grad():
+            output = self.probe(padded, lengths=lengths_tensor)
+        scores = output.logits.squeeze(-1).cpu().tolist()
+
+        results = []
+        for i, inputs in enumerate(input_list):
+            score = scores[i]
+            results.append({
+                "file": inputs if isinstance(inputs, str) else f"input_{i}",
+                "severity_score": round(max(1.0, min(7.0, score)), 2),
+                "raw_score": round(score, 4),
+                "model_name": self.current_model_name,
+            })
+        return results
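The new `batch_inference` method zero-pads variable-length Whisper feature sequences into a single `(batch, max_len, hidden_dim)` tensor before one probe forward pass. A dependency-free sketch of that pad-and-batch step using nested lists instead of tensors; the lengths and hidden size here are made up for illustration:

```python
hidden_dim = 4
lengths = [3, 5, 2]  # hypothetical per-file feature lengths T_i after VAD

# Stand-ins for Whisper feature sequences: each is a T_i x hidden_dim grid.
all_features = [[[1.0] * hidden_dim for _ in range(t)] for t in lengths]

max_len = max(lengths)

# Zero-pad each sequence to max_len, mirroring the torch.zeros buffer plus
# row-wise copy in batch_inference.
padded = [
    feat + [[0.0] * hidden_dim] * (max_len - len(feat))
    for feat in all_features
]

print(len(padded), len(padded[0]))  # 3 5: a batch of 3, all padded to length 5
```

The real method also keeps a `lengths` tensor so the probe can mask the padded frames out of its pooling.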
requirements.txt ADDED
@@ -0,0 +1,8 @@
+# Install PyTorch separately first (via conda or pip):
+# conda install pytorch torchaudio -c pytorch -y
+# See https://pytorch.org/get-started/locally/ for GPU builds.
+
+transformers>=4.40.0
+safetensors>=0.4.0
+huggingface_hub>=0.20.0
+soundfile>=0.12.0
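The comment block defers PyTorch to a separate conda/pip step, so the file itself is just plain `name>=version` pins. A small stdlib-only sketch that parses such lines (naive integer-tuple versions, not a substitute for pip's resolver; `parse_requirement` is a hypothetical helper):

```python
def parse_requirement(line: str):
    """Split a 'name>=version' pin into (name, version-as-int-tuple)."""
    name, _, version = line.partition(">=")
    return name.strip(), tuple(int(p) for p in version.strip().split("."))

# The four pins from requirements.txt above.
reqs = [
    "transformers>=4.40.0",
    "safetensors>=0.4.0",
    "huggingface_hub>=0.20.0",
    "soundfile>=0.12.0",
]
minimums = dict(parse_requirement(r) for r in reqs)
print(minimums["transformers"])  # (4, 40, 0)
```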
test.py ADDED
@@ -0,0 +1,12 @@
+from huggingface_hub import snapshot_download
+
+# Download the model from HuggingFace
+model_dir = snapshot_download("jaesungbae/da-dsqa")
+
+# Load pipeline (defaults to proposed_L_coarse_tau100.0)
+from pipeline import PreTrainedPipeline
+pipe = PreTrainedPipeline(model_dir)
+
+# Run inference
+result = pipe("/projects/bedl/jbae4/workspace_2026/severity_level_classifier_release/sample_wavs/Naturalness/level_1/d1b9444a-2ed1-438e-fd68-08dcb5d1edd7_1071_8831.wav")
+print(result)