DhruvB1906 committed on
Commit
f468cc8
·
verified ·
1 Parent(s): 4e9a3bc

Upload folder using huggingface_hub

Browse files
training/augmentation.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio data augmentation for training."""
2
+
3
+ import numpy as np
4
+ import librosa
5
+
6
+
7
class AudioAugmenter:
    """Apply random waveform-level augmentations for data diversity.

    Three augmentations are available: time stretching, pitch shifting and
    additive Gaussian noise. Each one is applied independently with
    probability ``apply_prob`` on every call to :meth:`augment`.
    """

    def __init__(
        self,
        time_stretch_range=(0.8, 1.2),
        pitch_shift_range=(-2, 2),
        noise_level_range=(0.005, 0.015),
        apply_prob=0.5,
    ):
        """
        Initialize augmenter.

        Args:
            time_stretch_range: (min_rate, max_rate) for time stretching
            pitch_shift_range: (min_steps, max_steps) for pitch shifting
            noise_level_range: (min_level, max_level) for additive noise
            apply_prob: Probability of applying each augmentation
        """
        self.time_stretch_range = time_stretch_range
        self.pitch_shift_range = pitch_shift_range
        self.noise_level_range = noise_level_range
        self.apply_prob = apply_prob

    def augment(self, waveform: np.ndarray, sr: int) -> np.ndarray:
        """
        Apply random augmentations.

        Randomness comes from NumPy's global ``np.random`` state, so results
        are reproducible only if the caller seeds that state.

        Args:
            waveform: Audio waveform
            sr: Sample rate

        Returns:
            Augmented waveform. Its length differs from the input when time
            stretching was applied.
        """
        # Time stretching
        if np.random.rand() < self.apply_prob:
            rate = np.random.uniform(*self.time_stretch_range)
            waveform = librosa.effects.time_stretch(waveform, rate=rate)

        # Pitch shifting
        if np.random.rand() < self.apply_prob:
            n_steps = np.random.uniform(*self.pitch_shift_range)
            waveform = librosa.effects.pitch_shift(waveform, sr=sr, n_steps=n_steps)

        # Additive noise
        if np.random.rand() < self.apply_prob:
            noise_level = np.random.uniform(*self.noise_level_range)
            # Match the full waveform shape (not just len()) so multi-channel
            # input gets independent noise per sample.
            noise = np.random.standard_normal(waveform.shape) * noise_level
            # Keep float inputs in their own dtype; float64 noise would
            # otherwise silently upcast float32 audio.
            if np.issubdtype(waveform.dtype, np.floating):
                noise = noise.astype(waveform.dtype, copy=False)
            waveform = waveform + noise

        # SpecAugment (applied at spectrogram level, not here)

        return waveform
training/calibrate.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Calibrate model probabilities using Platt scaling.
3
+
4
+ This script:
5
+ 1. Loads the ensemble model
6
+ 2. Collects predictions on a held-out calibration set
7
+ 3. Fits Platt scaling parameters (a, b) via logistic regression
8
+ 4. Evaluates calibration quality (ECE, reliability diagrams)
9
+
10
+ Usage:
11
+ python training/calibrate.py
12
+ """
13
+
14
+ import sys
15
+ from pathlib import Path
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.utils.data import DataLoader
21
+ import mlflow
22
+ import numpy as np
23
+ from tqdm import tqdm
24
+ import yaml
25
+ import logging
26
+ from sklearn.linear_model import LogisticRegression
27
+ from sklearn.metrics import brier_score_loss, log_loss
28
+ import matplotlib.pyplot as plt
29
+
30
+ from training.dataset import DysarthriaDataset
31
+ from training.train_hubert_salr import HuBERTSALRModel
32
+ from training.train_cnn_bilstm import CNNBiLSTMTransformer
33
+
34
+
35
+ # ══════════════════════════════════════════════════════════════════════════════
36
+ # Calibration Metrics
37
+ # ══════════════════════════════════════════════════════════════════════════════
38
+
39
def expected_calibration_error(y_true, y_prob, n_bins=10):
    """
    Compute Expected Calibration Error (ECE).

    ECE measures the difference between predicted confidence and actual accuracy.
    Lower ECE indicates better calibration. Predictions are grouped into
    ``n_bins`` equal-width bins over [0, 1]; each non-empty bin contributes
    ``|mean(y_true) - mean(y_prob)|`` weighted by its share of the samples.

    Args:
        y_true: True labels (0 or 1); any array-like, flattened to 1-D
        y_prob: Predicted probabilities (0 to 1); any array-like, flattened to 1-D
        n_bins: Number of bins for binning predictions

    Returns:
        ECE value (0.0 for empty input)
    """
    # Accept lists and column vectors, not just 1-D ndarrays.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    if y_true.size == 0:
        return 0.0

    bin_edges = np.linspace(0, 1, n_bins + 1)
    # digitize against the left edges; clip so that p == 1.0 (and any
    # out-of-range value) lands in the first/last bin.
    bin_indices = np.clip(np.digitize(y_prob, bin_edges[:-1]) - 1, 0, n_bins - 1)

    # weight * |acc - conf| == |sum(y_true) - sum(y_prob)| / N for each bin,
    # so per-bin sums are enough and empty bins contribute exactly zero.
    true_sums = np.bincount(bin_indices, weights=y_true, minlength=n_bins)
    prob_sums = np.bincount(bin_indices, weights=y_prob, minlength=n_bins)

    return float(np.abs(true_sums - prob_sums).sum() / y_true.size)
68
+
69
+
70
def reliability_curve(y_true, y_prob, n_bins=10):
    """
    Compute reliability curve data for plotting.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1];
    only non-empty bins are reported.

    Args:
        y_true: True labels (0 or 1); any array-like, flattened to 1-D
        y_prob: Predicted probabilities (0 to 1); any array-like, flattened to 1-D
        n_bins: Number of bins for binning predictions

    Returns:
        bin_centers, bin_accuracies, bin_confidences, bin_counts
        (numpy arrays with one entry per non-empty bin)
    """
    # Accept lists and column vectors, not just 1-D ndarrays.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()

    bin_edges = np.linspace(0, 1, n_bins + 1)
    # digitize against the left edges; clip so that p == 1.0 stays in the last bin.
    bin_indices = np.clip(np.digitize(y_prob, bin_edges[:-1]) - 1, 0, n_bins - 1)

    counts = np.bincount(bin_indices, minlength=n_bins)
    occupied = np.flatnonzero(counts)  # indices of non-empty bins, ascending
    occupied_counts = counts[occupied]

    true_sums = np.bincount(bin_indices, weights=y_true, minlength=n_bins)
    prob_sums = np.bincount(bin_indices, weights=y_prob, minlength=n_bins)

    bin_centers = (bin_edges[:-1] + bin_edges[1:])[occupied] / 2
    bin_accuracies = true_sums[occupied] / occupied_counts
    bin_confidences = prob_sums[occupied] / occupied_counts

    return bin_centers, bin_accuracies, bin_confidences, occupied_counts
100
+
101
+
102
+ # ══════════════════════════════════════════════════════════════════════════════
103
+ # Model Inference
104
+ # ══════════════════════════════════════════════════════════════════════════════
105
+
106
def collect_predictions(hubert_model, cnn_model, dataloader, alpha, device):
    """
    Collect raw logits and probabilities from ensemble.

    The ensemble logits are ``alpha * hubert + (1 - alpha) * cnn``. The
    returned "logit" is the class-1 column of that blend, and the returned
    probability is the class-1 softmax over both columns.

    Args:
        hubert_model: HuBERT-SALR model
        cnn_model: CNN-BiLSTM model
        dataloader: Data loader yielding dicts with "waveform", "spectrogram"
            and "label" entries
        alpha: Ensemble mixing weight
        device: torch device

    Returns:
        logits, probabilities, true labels (all 1-D numpy arrays)
    """
    all_logits = []
    all_probs = []
    all_labels = []

    hubert_model.eval()
    cnn_model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Collecting predictions"):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)
            # Dataset items carry labels of shape (1,), so a batch arrives as
            # (B, 1). Flatten to (B,) so downstream boolean masks line up
            # with the 1-D probability arrays.
            labels = batch["label"].reshape(-1)

            # Ensemble logits
            hubert_logits = hubert_model(waveform)
            cnn_logits = cnn_model(spectrogram)
            ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits

            # Probabilities (uncalibrated)
            probs = torch.softmax(ensemble_logits, dim=1)[:, 1]

            # NOTE(review): Platt scaling is fitted on the class-1 logit only,
            # not on the margin (logit_1 - logit_0) that the softmax above
            # depends on. evaluate.py uses the same convention, so both must
            # change together if this is revisited.
            all_logits.extend(ensemble_logits[:, 1].cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return (
        np.array(all_logits),
        np.array(all_probs),
        np.array(all_labels),
    )
150
+
151
+
152
+ # ══════════════════════════════════════════════════════════════════════════════
153
+ # Platt Scaling
154
+ # ══════════════════════════════════════════════════════════════════════════════
155
+
156
def fit_platt_scaling(logits, labels):
    """
    Fit Platt scaling parameters.

    Platt scaling fits:
        calibrated_prob = sigmoid(a * logit + b)

    Args:
        logits: Raw model logits (n_samples,)
        labels: True binary labels; (n_samples,) or (n_samples, 1)

    Returns:
        a, b parameters as Python floats
    """
    # sklearn wants a 2-D feature matrix and a 1-D target. Flatten labels so
    # column-vector input does not trigger a conversion warning.
    X = np.asarray(logits, dtype=float).reshape(-1, 1)
    y = np.asarray(labels).ravel()

    # Fit logistic regression (no regularization): a single-feature logistic
    # model is exactly the Platt form above.
    lr = LogisticRegression(penalty=None, solver="lbfgs", max_iter=1000)
    lr.fit(X, y)

    a = float(lr.coef_[0][0])
    b = float(lr.intercept_[0])

    return a, b
182
+
183
+
184
def apply_platt_scaling(logits, a, b):
    """Apply Platt scaling to logits.

    Computes ``sigmoid(a * logits + b)`` in a numerically stable way.

    Args:
        logits: Raw model logits (scalar or array-like)
        a: Platt slope
        b: Platt intercept

    Returns:
        Calibrated probabilities in [0, 1], same shape as ``logits``
    """
    z = a * np.asarray(logits, dtype=float) + b
    # sigmoid(z) == exp(-log(1 + exp(-z))). logaddexp evaluates the inner log
    # without overflowing, whereas a bare np.exp(-z) overflows for large
    # negative z and emits a RuntimeWarning.
    calibrated_probs = np.exp(-np.logaddexp(0.0, -z))
    return calibrated_probs
189
+
190
+
191
+ # ══════════════════════════════════════════════════════════════════════════════
192
+ # Visualization
193
+ # ══════════════════════════════════════════════════════════════════════════════
194
+
195
def plot_reliability_diagram(
    y_true,
    y_prob_uncal,
    y_prob_cal,
    output_path: Path,
):
    """Save a side-by-side reliability diagram (uncalibrated vs calibrated).

    Each panel shows the per-bin fraction of positives against the mean
    predicted probability, with marker area scaled by bin population, plus the
    diagonal of perfect calibration. ECE and Brier score go in the panel title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (("Uncalibrated", y_prob_uncal), ("Calibrated", y_prob_cal))

    for ax, (title, probs) in zip(axes, panels):
        _, accs, confs, counts = reliability_curve(y_true, probs, n_bins=10)

        # Summary numbers shown in the title
        ece = expected_calibration_error(y_true, probs)
        brier = brier_score_loss(y_true, probs)

        # Reference diagonal, then the model's curve (bubbles + connecting line)
        ax.plot([0, 1], [0, 1], "k--", label="Perfect calibration", linewidth=2)
        ax.scatter(confs, accs, s=counts * 3, alpha=0.6, label="Model", zorder=5)
        ax.plot(confs, accs, "o-", linewidth=2, markersize=8)

        ax.set_title(f"{title}\nECE: {ece:.4f}, Brier: {brier:.4f}", fontsize=14)
        ax.set_xlabel("Mean Predicted Probability", fontsize=12)
        ax.set_ylabel("Fraction of Positives", fontsize=12)
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=10)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
231
+
232
+
233
def plot_histogram_comparison(
    y_true,
    y_prob_uncal,
    y_prob_cal,
    output_path: Path,
):
    """Save a 2x2 grid of predicted-probability histograms.

    Rows are uncalibrated / calibrated probabilities; columns are samples
    whose true label is 1 (dysarthric) / 0 (healthy).
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # One column per true class: (mask, bar colour, class name for the title)
    columns = (
        (y_true == 1, "red", "Dysarthric"),
        (y_true == 0, "blue", "Healthy"),
    )
    rows = ((y_prob_uncal, "Uncalibrated"), (y_prob_cal, "Calibrated"))

    for row_idx, (probs, title) in enumerate(rows):
        for col_idx, (class_mask, colour, class_name) in enumerate(columns):
            ax = axes[row_idx, col_idx]
            ax.hist(probs[class_mask], bins=20, alpha=0.7, color=colour, edgecolor="black")
            ax.set_xlabel("Predicted Probability", fontsize=12)
            ax.set_ylabel("Count", fontsize=12)
            ax.set_title(f"{title} - True {class_name}", fontsize=14)
            ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
266
+
267
+
268
+ # ══════════════════════════════════════════════════════════════════════════════
269
+ # Main
270
+ # ══════════════════════════════════════════════════════════════════════════════
271
+
272
def main():
    """Run the Platt-scaling calibration workflow end to end.

    Steps: load both ensemble members and the mixing weight, collect ensemble
    predictions on the validation manifest, fit Platt parameters, compare
    ECE / Brier / log-loss before and after calibration, log everything to
    MLflow, and write the parameters and plots to ``reports/calibration``.

    All paths are relative to the current working directory.
    """
    # Local import: collate_fn lives next to DysarthriaDataset and pads
    # variable-length waveforms/spectrograms so batches can be stacked.
    from training.dataset import collate_fn

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ──────────────────────────────────────────────────────────────────────
    # Load Models
    # ──────────────────────────────────────────────────────────────────────
    logger.info("Loading models...")

    hubert_checkpoint = Path("models/hubert_salr_best.pt")
    cnn_checkpoint = Path("models/cnn_bilstm_best.pt")

    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    hubert_model = HuBERTSALRModel()
    hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"])
    hubert_model.to(device)

    cnn_model = CNNBiLSTMTransformer()
    cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"])
    cnn_model.to(device)

    # Load optimal alpha
    with open("configs/model_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    alpha = config.get("ensemble", {}).get("alpha", 0.6)
    logger.info(f"Using ensemble alpha: {alpha}")

    # ──────────────────────────────────────────────────────────────────────
    # Load Calibration Data (use validation set)
    # ──────────────────────────────────────────────────────────────────────
    val_manifest = Path("data/manifests/val.csv")
    # DysarthriaDataset only accepts (manifest_path, augment, cache_features);
    # calibration must see un-augmented audio.
    val_dataset = DysarthriaDataset(val_manifest, augment=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
    )

    logger.info(f"Calibration samples: {len(val_dataset)}")

    # ──────────────────────────────────────────────────────────────────────
    # Collect Predictions
    # ──────────────────────────────────────────────────────────────────────
    logger.info("Collecting predictions...")

    logits, probs_uncal, labels = collect_predictions(
        hubert_model, cnn_model, val_loader, alpha, device
    )
    # Dataset labels have shape (1,) per item; make sure the label vector is
    # 1-D so boolean masks match the 1-D probability arrays.
    labels = np.asarray(labels).ravel()

    # ──────────────────────────────────────────────────────────────────────
    # Fit Platt Scaling
    # ──────────────────────────────────────────────────────────────────────
    mlflow.set_experiment("model_calibration")

    with mlflow.start_run():
        logger.info("\nFitting Platt scaling...")

        a, b = fit_platt_scaling(logits, labels)
        logger.info(f"Platt parameters: a={a:.6f}, b={b:.6f}")

        # Apply calibration
        probs_cal = apply_platt_scaling(logits, a, b)

        # ──────────────────────────────────────────────────────────────────
        # Evaluate Calibration
        # ──────────────────────────────────────────────────────────────────
        # NOTE: the "calibrated" metrics are computed on the same samples the
        # Platt parameters were fitted on, so they are optimistic.
        ece_uncal = expected_calibration_error(labels, probs_uncal)
        ece_cal = expected_calibration_error(labels, probs_cal)

        brier_uncal = brier_score_loss(labels, probs_uncal)
        brier_cal = brier_score_loss(labels, probs_cal)

        logloss_uncal = log_loss(labels, probs_uncal)
        logloss_cal = log_loss(labels, probs_cal)

        logger.info("\n" + "=" * 80)
        logger.info("CALIBRATION RESULTS")
        logger.info("=" * 80)
        logger.info("Expected Calibration Error (ECE):")
        logger.info(f" Uncalibrated: {ece_uncal:.4f}")
        logger.info(f" Calibrated: {ece_cal:.4f} ({'↓' if ece_cal < ece_uncal else '↑'} {abs(ece_cal - ece_uncal):.4f})")
        logger.info("\nBrier Score:")
        logger.info(f" Uncalibrated: {brier_uncal:.4f}")
        logger.info(f" Calibrated: {brier_cal:.4f} ({'↓' if brier_cal < brier_uncal else '↑'} {abs(brier_cal - brier_uncal):.4f})")
        logger.info("\nLog Loss:")
        logger.info(f" Uncalibrated: {logloss_uncal:.4f}")
        logger.info(f" Calibrated: {logloss_cal:.4f} ({'↓' if logloss_cal < logloss_uncal else '↑'} {abs(logloss_cal - logloss_uncal):.4f})")
        logger.info("=" * 80)

        # Log to MLflow
        mlflow.log_params({
            "platt_a": a,
            "platt_b": b,
            "calibration_samples": len(labels),
        })

        mlflow.log_metrics({
            "ece_uncalibrated": ece_uncal,
            "ece_calibrated": ece_cal,
            "ece_improvement": ece_uncal - ece_cal,
            "brier_uncalibrated": brier_uncal,
            "brier_calibrated": brier_cal,
            "logloss_uncalibrated": logloss_uncal,
            "logloss_calibrated": logloss_cal,
        })

        # ──────────────────────────────────────────────────────────────────
        # Save Results
        # ──────────────────────────────────────────────────────────────────
        output_dir = Path("reports/calibration")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save Platt parameters (cast to plain floats so YAML stays readable)
        calibration_config = {
            "platt_scaling": {
                "a": float(a),
                "b": float(b),
                "ece_uncalibrated": float(ece_uncal),
                "ece_calibrated": float(ece_cal),
                "brier_uncalibrated": float(brier_uncal),
                "brier_calibrated": float(brier_cal),
            }
        }

        config_path = output_dir / "calibration_params.yaml"
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(calibration_config, f, default_flow_style=False)
        mlflow.log_artifact(str(config_path))
        logger.info(f"\n✓ Calibration parameters saved to {config_path}")

        # Plot reliability diagram
        reliability_path = output_dir / "reliability_diagram.png"
        plot_reliability_diagram(labels, probs_uncal, probs_cal, reliability_path)
        mlflow.log_artifact(str(reliability_path))
        logger.info(f"✓ Reliability diagram saved to {reliability_path}")

        # Plot histogram comparison
        hist_path = output_dir / "probability_histograms.png"
        plot_histogram_comparison(labels, probs_uncal, probs_cal, hist_path)
        mlflow.log_artifact(str(hist_path))
        logger.info(f"✓ Probability histograms saved to {hist_path}")

        logger.info("\n✓ Calibration complete!")
        logger.info(" Update configs/model_config.yaml with Platt parameters:")
        logger.info(f" a: {a:.6f}")
        logger.info(f" b: {b:.6f}")
416
+
417
+
418
+ if __name__ == "__main__":
419
+ main()
training/dataset.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch Dataset for dysarthria detection."""
2
+
3
+ import logging
4
+ import pandas as pd
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from pathlib import Path
9
+
10
+ from src.ingestion.audio_loader import AudioLoader
11
+ from src.ingestion.preprocessor import AudioPreprocessor
12
+ from src.features.mfcc_extractor import MFCCExtractor
13
+ from src.features.prosodic_extractor import ProsodicExtractor
14
+ from src.features.formant_extractor import FormantExtractor
15
+ from src.features.egemaps_extractor import EGeMAPSExtractor
16
+ from src.features.spectrogram_builder import SpectrogramBuilder
17
+ from src.features.feature_fusion import FeatureFusion
18
+ from src.features.schemas import FeatureBundle
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class DysarthriaDataset(Dataset):
24
+ """Dataset for dysarthria detection with on-the-fly feature extraction."""
25
+
26
+ def __init__(
27
+ self,
28
+ manifest_path: str | Path,
29
+ augment: bool = False,
30
+ cache_features: bool = False,
31
+ ):
32
+ """
33
+ Initialize dataset.
34
+
35
+ Args:
36
+ manifest_path: Path to CSV manifest (filepath, label, speaker_id, duration)
37
+ augment: Apply data augmentation
38
+ cache_features: Cache extracted features in memory
39
+ """
40
+ self.manifest = pd.read_csv(manifest_path)
41
+ self.augment = augment
42
+ self.cache_features = cache_features
43
+ self.feature_cache = {} if cache_features else None
44
+
45
+ # Initialize components
46
+ self.audio_loader = AudioLoader()
47
+ self.preprocessor = AudioPreprocessor(target_sr=16000)
48
+ self.mfcc_extractor = MFCCExtractor()
49
+ self.prosodic_extractor = ProsodicExtractor()
50
+ self.formant_extractor = FormantExtractor()
51
+ self.egemaps_extractor = EGeMAPSExtractor()
52
+ self.spectrogram_builder = SpectrogramBuilder()
53
+ self.feature_fusion = FeatureFusion()
54
+
55
+ logger.info(f"Dataset initialized: {len(self)} samples")
56
+
57
+ def __len__(self) -> int:
58
+ return len(self.manifest)
59
+
60
+ def __getitem__(self, idx: int) -> dict:
61
+ """
62
+ Get item by index.
63
+
64
+ Returns:
65
+ dict with keys:
66
+ - waveform: torch.Tensor (samples,)
67
+ - spectrogram: torch.Tensor (2, freq, time)
68
+ - acoustic_features: torch.Tensor (n_features,)
69
+ - label: torch.Tensor (1,)
70
+ - speaker_id: str
71
+ """
72
+ # Check cache
73
+ if self.cache_features and idx in self.feature_cache:
74
+ return self.feature_cache[idx]
75
+
76
+ # Load sample info
77
+ row = self.manifest.iloc[idx]
78
+ audio_path = row["file_path"] # Changed from "filepath" to "file_path"
79
+ label = int(row["label"])
80
+ speaker_id = row["speaker_id"]
81
+
82
+ try:
83
+ # Load and preprocess audio
84
+ audio_input, waveform = self.audio_loader.load(audio_path)
85
+ preprocessed = self.preprocessor.process(
86
+ waveform,
87
+ sr=audio_input.sample_rate, # Use original SR
88
+ original_duration=row["duration"],
89
+ )
90
+
91
+ waveform = preprocessed.waveform
92
+ sr = preprocessed.sample_rate
93
+
94
+ # Apply augmentation if training
95
+ if self.augment:
96
+ waveform = self._apply_augmentation(waveform, sr)
97
+
98
+ # Extract features
99
+ mfcc = self.mfcc_extractor.extract(waveform, sr)
100
+ prosody = self.prosodic_extractor.extract(waveform, sr)
101
+ formants = self.formant_extractor.extract(waveform, sr)
102
+ egemaps = self.egemaps_extractor.extract(waveform, sr)
103
+ spectrogram = self.spectrogram_builder.build(waveform, sr)
104
+
105
+ # Create feature bundle
106
+ feature_bundle = FeatureBundle(
107
+ waveform=waveform,
108
+ sample_rate=sr,
109
+ duration_sec=preprocessed.duration_sec,
110
+ mfcc=mfcc,
111
+ prosody=prosody,
112
+ formants=formants,
113
+ egemaps=egemaps,
114
+ spectrogram=spectrogram,
115
+ )
116
+
117
+ # Fuse acoustic features
118
+ feature_bundle = self.feature_fusion.fuse(feature_bundle)
119
+
120
+ # Convert to tensors
121
+ item = {
122
+ "waveform": torch.from_numpy(waveform).float(),
123
+ "spectrogram": torch.from_numpy(spectrogram.stacked).float(),
124
+ "acoustic_features": torch.from_numpy(feature_bundle.fused_acoustic).float(),
125
+ "label": torch.tensor([label], dtype=torch.long),
126
+ "speaker_id": speaker_id,
127
+ }
128
+
129
+ # Cache if enabled
130
+ if self.cache_features:
131
+ self.feature_cache[idx] = item
132
+
133
+ return item
134
+
135
+ except Exception as e:
136
+ logger.error(f"Failed to load sample {idx} ({audio_path}): {e}")
137
+ # Return a dummy sample
138
+ return self._get_dummy_item(label, speaker_id)
139
+
140
+ def _apply_augmentation(self, waveform: np.ndarray, sr: int) -> np.ndarray:
141
+ """Apply data augmentation."""
142
+ from training.augmentation import AudioAugmenter
143
+
144
+ augmenter = AudioAugmenter()
145
+ return augmenter.augment(waveform, sr)
146
+
147
+ def _get_dummy_item(self, label: int, speaker_id: str) -> dict:
148
+ """Return a dummy item when loading fails."""
149
+ return {
150
+ "waveform": torch.zeros(16000 * 10), # 10 seconds of silence
151
+ "spectrogram": torch.zeros(2, 128, 313),
152
+ "acoustic_features": torch.zeros(145),
153
+ "label": torch.tensor([label], dtype=torch.long),
154
+ "speaker_id": speaker_id,
155
+ }
156
+
157
+
158
def collate_fn(batch: list[dict]) -> dict:
    """
    Collate function for DataLoader.

    Handles variable-length sequences by padding: waveforms are zero-padded
    on the right to the longest waveform in the batch, and spectrograms are
    zero-padded along their last (time) axis to the most time frames.

    Args:
        batch: List of dataset items (dicts with "waveform", "spectrogram",
            "acoustic_features", "label", "speaker_id" and optionally
            "file_path")

    Returns:
        dict of stacked tensors, plus "speaker_id" and "file_path" as lists
        (a missing "file_path" becomes "")

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    # Find max lengths
    max_waveform_len = max(item["waveform"].shape[0] for item in batch)
    max_time_frames = max(item["spectrogram"].shape[2] for item in batch)

    # Pad sequences
    waveforms = []
    spectrograms = []

    for item in batch:
        # Pad waveform (F.pad with a 2-tuple pads the last dimension)
        waveform = item["waveform"]
        if waveform.shape[0] < max_waveform_len:
            waveform = torch.nn.functional.pad(
                waveform, (0, max_waveform_len - waveform.shape[0])
            )
        waveforms.append(waveform)

        # Pad spectrogram along its last (time) axis
        spec = item["spectrogram"]
        if spec.shape[2] < max_time_frames:
            spec = torch.nn.functional.pad(
                spec, (0, max_time_frames - spec.shape[2])
            )
        spectrograms.append(spec)

    return {
        "waveform": torch.stack(waveforms),
        "spectrogram": torch.stack(spectrograms),
        "acoustic_features": torch.stack([item["acoustic_features"] for item in batch]),
        "label": torch.stack([item["label"] for item in batch]),
        "speaker_id": [item["speaker_id"] for item in batch],
        # Forward the source path when the dataset provides it, so evaluation
        # code reading batch["file_path"] gets real paths.
        "file_path": [item.get("file_path", "") for item in batch],
    }
training/evaluate.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive model evaluation on test set.
3
+
4
+ This script:
5
+ 1. Loads trained ensemble model with calibration
6
+ 2. Evaluates on held-out test set
7
+ 3. Computes classification metrics (accuracy, F1, AUC, sensitivity, specificity)
8
+ 4. Generates confusion matrix, ROC curve, PR curve
9
+ 5. Performs error analysis
10
+
11
+ Usage:
12
+ python training/evaluate.py
13
+ """
14
+
15
+ import sys
16
+ from pathlib import Path
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.utils.data import DataLoader
22
+ import mlflow
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+ import yaml
26
+ import pandas as pd
27
+ import logging
28
+ from sklearn.metrics import (
29
+ accuracy_score,
30
+ f1_score,
31
+ roc_auc_score,
32
+ confusion_matrix,
33
+ classification_report,
34
+ roc_curve,
35
+ precision_recall_curve,
36
+ average_precision_score,
37
+ )
38
+ import matplotlib.pyplot as plt
39
+ import seaborn as sns
40
+
41
+ from training.dataset import DysarthriaDataset
42
+ from training.train_hubert_salr import HuBERTSALRModel
43
+ from training.train_cnn_bilstm import CNNBiLSTMTransformer
44
+
45
+
46
+ # ══════════════════════════════════════════════════════════════════════════════
47
+ # Model Inference
48
+ # ══════════════════════════════════════════════════════════════════════════════
49
+
50
def evaluate_model(hubert_model, cnn_model, dataloader, alpha, platt_a, platt_b, device):
    """
    Evaluate calibrated ensemble on test set.

    Ensemble logits are ``alpha * hubert + (1 - alpha) * cnn``. The class-1
    logit is passed through Platt scaling, ``sigmoid(platt_a * logit + platt_b)``,
    and thresholded at 0.5.

    Args:
        hubert_model: HuBERT-SALR model
        cnn_model: CNN-BiLSTM model
        dataloader: Data loader yielding dicts with "waveform", "spectrogram",
            "label" and optionally "file_path"
        alpha: Ensemble mixing weight
        platt_a: Platt scaling slope
        platt_b: Platt scaling intercept
        device: torch device

    Returns:
        predictions, probabilities, labels (1-D numpy arrays), file_paths (list)
    """
    all_preds = []
    all_probs = []
    all_labels = []
    all_files = []

    hubert_model.eval()
    cnn_model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)
            # Dataset items carry labels of shape (1,), so a batch arrives as
            # (B, 1). Flatten to (B,) so labels align element-wise with the
            # 1-D predictions and probabilities.
            labels = batch["label"].reshape(-1)
            # Batches without paths fall back to empty strings.
            file_paths = batch.get("file_path", [""] * len(labels))

            # Ensemble logits
            hubert_logits = hubert_model(waveform)
            cnn_logits = cnn_model(spectrogram)
            ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits

            # Apply Platt scaling
            raw_logits = ensemble_logits[:, 1].cpu().numpy()
            z = platt_a * raw_logits + platt_b
            # Stable sigmoid: exp(-log(1 + exp(-z))). A bare np.exp(-z)
            # overflows for large negative z.
            calibrated_probs = np.exp(-np.logaddexp(0.0, -z))

            # Predictions
            preds = (calibrated_probs > 0.5).astype(int)

            all_preds.extend(preds)
            all_probs.extend(calibrated_probs)
            all_labels.extend(labels.cpu().numpy())
            all_files.extend(file_paths)

    return (
        np.array(all_preds),
        np.array(all_probs),
        np.array(all_labels),
        all_files,
    )
96
+
97
+
98
+ # ══════════════════════════════════════════════════════════════════════════════
99
+ # Metrics Computation
100
+ # ══════════════════════════════════════════════════════════════════════════════
101
+
102
def compute_metrics(y_true, y_pred, y_prob):
    """Compute comprehensive classification metrics.

    Args:
        y_true: True binary labels (0 = healthy, 1 = dysarthric)
        y_pred: Predicted binary labels
        y_prob: Predicted probability of the positive class

    Returns:
        dict with accuracy, f1, auc, average_precision, sensitivity,
        specificity, ppv, npv, the four confusion counts (tp, tn, fp, fn)
        and the 2x2 confusion matrix. Ratios whose denominator is zero are
        reported as 0.

    Raises:
        ValueError: From ``roc_auc_score`` when ``y_true`` holds one class only.
    """
    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="binary")
    auc = roc_auc_score(y_true, y_prob)
    ap = average_precision_score(y_true, y_prob)

    # Confusion matrix. Pin labels to [0, 1] so the matrix is always 2x2;
    # otherwise a class absent from both y_true and y_pred shrinks it and the
    # four-way unpacking below fails.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Sensitivity and specificity
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # Positive and negative predictive value
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0

    return {
        "accuracy": accuracy,
        "f1": f1,
        "auc": auc,
        "average_precision": ap,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "ppv": ppv,
        "npv": npv,
        "tp": int(tp),
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "confusion_matrix": cm,
    }
137
+
138
+
139
+ # ══════════════════════════════════════════════════════════════════════════════
140
+ # Visualization
141
+ # ══════════════════════════════════════════════════════════════════════════════
142
+
143
def plot_confusion_matrix(cm, output_path: Path):
    """Render the confusion matrix as an annotated heatmap and save it.

    Rows are true labels and columns are predicted labels, both ordered
    Healthy then Dysarthric.
    """
    class_names = ["Healthy", "Dysarthric"]

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",  # integer counts in each cell
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={"label": "Count"},
    )
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("True Label", fontsize=14)
    plt.title("Confusion Matrix - Test Set", fontsize=16, fontweight="bold")

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
161
+
162
+
163
def plot_roc_curve(y_true, y_prob, auc_score, output_path: Path):
    """Save the ROC curve, with the chance diagonal drawn for reference.

    ``auc_score`` is only used for the legend text; it is not recomputed.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_prob)

    plt.figure(figsize=(8, 6))
    plt.plot(
        false_pos_rate,
        true_pos_rate,
        linewidth=2,
        label=f"Model (AUC = {auc_score:.4f})",
    )
    plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random Classifier")

    plt.title("ROC Curve - Test Set", fontsize=16, fontweight="bold")
    plt.xlabel("False Positive Rate", fontsize=14)
    plt.ylabel("True Positive Rate", fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
179
+
180
+
181
def plot_precision_recall_curve(y_true, y_prob, ap_score, output_path: Path):
    """Save the precision-recall curve with both axes fixed to [0, 1].

    ``ap_score`` is only used for the legend text; it is not recomputed.
    """
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_prob)

    plt.figure(figsize=(8, 6))
    plt.plot(
        recall_vals,
        precision_vals,
        linewidth=2,
        label=f"Model (AP = {ap_score:.4f})",
    )

    plt.title("Precision-Recall Curve - Test Set", fontsize=16, fontweight="bold")
    plt.xlabel("Recall", fontsize=14)
    plt.ylabel("Precision", fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.xlim([0, 1])
    plt.ylim([0, 1])

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
198
+
199
+
200
def plot_probability_distribution(y_true, y_prob, output_path: Path):
    """Plot overlaid histograms of predicted probability, one per true class.

    Args:
        y_true: Ground-truth labels (0 = Healthy, 1 = Dysarthric).
        y_prob: Predicted probabilities, indexed with a boolean mask built
            from ``y_true`` -- assumes both are NumPy arrays of equal
            length; TODO confirm against the caller.
        output_path: Destination image file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # (label value, bar colour, legend text); Healthy is drawn first so the
    # Dysarthric bars are layered on top.
    class_styles = (
        (0, "blue", "Healthy"),
        (1, "red", "Dysarthric"),
    )
    for class_id, colour, legend_text in class_styles:
        ax.hist(
            y_prob[y_true == class_id],
            bins=30,
            alpha=0.6,
            color=colour,
            label=legend_text,
            edgecolor="black",
        )

    # Marker drawn at 0.5 on the probability axis.
    ax.axvline(0.5, color="black", linestyle="--", linewidth=2, label="Decision Threshold")

    ax.set_xlabel("Predicted Probability", fontsize=14)
    ax.set_ylabel("Count", fontsize=14)
    ax.set_title("Predicted Probability Distribution - Test Set", fontsize=16, fontweight="bold")
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
234
+
235
+
236
+ # ══════════════════════════════════════════════════════════════════════════════
237
+ # Error Analysis
238
+ # ══════════════════════════════════════════════════════════════════════════════
239
+
240
def perform_error_analysis(y_true, y_pred, y_prob, file_paths, output_path: Path):
    """Collect every misclassified sample, write them to CSV and return them.

    Args:
        y_true: Ground-truth labels (1 = Dysarthric, anything else = Healthy).
        y_pred: Predicted labels, aligned with ``y_true``.
        y_prob: Predicted probability of the positive (Dysarthric) class.
        file_paths: Source file of each sample, aligned with ``y_true``.
        output_path: Destination CSV file.

    Returns:
        DataFrame with one row per error, sorted by descending confidence
        (most confidently wrong first). The columns are always present, so
        a perfect classifier yields an empty frame with headers rather than
        raising ``KeyError``.
    """
    columns = [
        "file_path",
        "true_label",
        "predicted_label",
        "probability",
        "confidence",
        "error_type",
    ]
    errors = []

    for true_label, pred_label, prob, file_path in zip(y_true, y_pred, y_prob, file_paths):
        if true_label == pred_label:
            continue

        predicted_positive = pred_label == 1
        errors.append({
            "file_path": file_path,
            "true_label": "Dysarthric" if true_label == 1 else "Healthy",
            "predicted_label": "Dysarthric" if predicted_positive else "Healthy",
            "probability": prob,
            # Confidence is the probability assigned to the (wrong) predicted class.
            "confidence": prob if predicted_positive else (1 - prob),
            "error_type": "False Positive" if predicted_positive else "False Negative",
        })

    # Passing the column list keeps the schema stable when ``errors`` is empty.
    errors_df = pd.DataFrame(errors, columns=columns)
    errors_df = errors_df.sort_values("confidence", ascending=False)
    errors_df.to_csv(output_path, index=False)

    return errors_df
265
+
266
+
267
+ # ══════════════════════════════════════════════════════════════════════════════
268
+ # Main
269
+ # ══════════════════════════════════════════════════════════════════════════════
270
+
271
def _to_builtin(value):
    """Unwrap zero-dimensional NumPy-style scalars into plain Python numbers."""
    unwrap = getattr(value, "item", None)
    if callable(unwrap) and getattr(value, "ndim", None) == 0:
        return unwrap()
    return value


def main():
    """Evaluate the HuBERT-SALR + CNN-BiLSTM ensemble on the held-out test set.

    Steps:
        1. Read the ensemble weight ``alpha`` from ``configs/model_config.yaml``
           and, when available, Platt-scaling parameters from
           ``reports/calibration/calibration_params.yaml`` (falls back to the
           identity mapping ``a=1, b=0``).
        2. Restore both model checkpoints and put them in inference mode.
        3. Run ``evaluate_model`` / ``compute_metrics`` on
           ``data/manifests/test.csv``.
        4. Log parameters, metrics and artifacts to MLflow, and write the
           metrics YAML, classification report, plots and misclassification
           CSV under ``reports/evaluation``.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Configuration
    # ──────────────────────────────────────────────────────────────────────────
    with open("configs/model_config.yaml") as f:
        config = yaml.safe_load(f)

    alpha = config.get("ensemble", {}).get("alpha", 0.6)

    # Load Platt scaling parameters
    calibration_file = Path("reports/calibration/calibration_params.yaml")
    if calibration_file.exists():
        with open(calibration_file) as f:
            cal_config = yaml.safe_load(f)
        platt_a = cal_config["platt_scaling"]["a"]
        platt_b = cal_config["platt_scaling"]["b"]
        logger.info(f"Loaded Platt parameters: a={platt_a:.6f}, b={platt_b:.6f}")
    else:
        platt_a, platt_b = 1.0, 0.0
        logger.warning("Calibration parameters not found, using identity mapping")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Models
    # ──────────────────────────────────────────────────────────────────────────
    logger.info("Loading models...")

    hubert_checkpoint = Path("models/hubert_salr_best.pt")
    cnn_checkpoint = Path("models/cnn_bilstm_best.pt")

    hubert_model = HuBERTSALRModel()
    hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"])
    hubert_model.to(device)
    # Inference mode: disables dropout and freezes batch-norm statistics so
    # the reported test metrics are deterministic.
    hubert_model.eval()

    cnn_model = CNNBiLSTMTransformer()
    cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"])
    cnn_model.to(device)
    cnn_model.eval()

    # ──────────────────────────────────────────────────────────────────────────
    # Load Test Data
    # ──────────────────────────────────────────────────────────────────────────
    test_manifest = Path("data/manifests/test.csv")
    test_dataset = DysarthriaDataset(test_manifest, augmentor=None, mode="test")
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

    logger.info(f"Test samples: {len(test_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Evaluate
    # ──────────────────────────────────────────────────────────────────────────
    mlflow.set_experiment("model_evaluation")

    with mlflow.start_run():
        logger.info("\nEvaluating on test set...")

        y_pred, y_prob, y_true, file_paths = evaluate_model(
            hubert_model, cnn_model, test_loader, alpha, platt_a, platt_b, device
        )

        # Compute metrics
        metrics = compute_metrics(y_true, y_pred, y_prob)

        # ──────────────────────────────────────────────────────────────────────
        # Print Results
        # ──────────────────────────────────────────────────────────────────────
        logger.info("\n" + "=" * 80)
        logger.info("TEST SET EVALUATION RESULTS")
        logger.info("=" * 80)
        logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
        logger.info(f"F1 Score: {metrics['f1']:.4f}")
        logger.info(f"AUC-ROC: {metrics['auc']:.4f}")
        logger.info(f"Average Precision: {metrics['average_precision']:.4f}")
        logger.info(f"Sensitivity: {metrics['sensitivity']:.4f}")
        logger.info(f"Specificity: {metrics['specificity']:.4f}")
        logger.info(f"PPV: {metrics['ppv']:.4f}")
        logger.info(f"NPV: {metrics['npv']:.4f}")
        logger.info("")
        logger.info("Confusion Matrix:")
        logger.info(f" True Negatives: {metrics['tn']}")
        logger.info(f" False Positives: {metrics['fp']}")
        logger.info(f" False Negatives: {metrics['fn']}")
        logger.info(f" True Positives: {metrics['tp']}")
        logger.info("=" * 80)

        # Log to MLflow
        mlflow.log_params({
            "ensemble_alpha": alpha,
            "platt_a": platt_a,
            "platt_b": platt_b,
            "test_samples": len(y_true),
        })

        mlflow.log_metrics({
            "test_accuracy": metrics["accuracy"],
            "test_f1": metrics["f1"],
            "test_auc": metrics["auc"],
            "test_ap": metrics["average_precision"],
            "test_sensitivity": metrics["sensitivity"],
            "test_specificity": metrics["specificity"],
            "test_ppv": metrics["ppv"],
            "test_npv": metrics["npv"],
        })

        # ──────────────────────────────────────────────────────────────────────
        # Save Results
        # ──────────────────────────────────────────────────────────────────────
        output_dir = Path("reports/evaluation")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save metrics
        metrics_file = output_dir / "test_metrics.yaml"
        with open(metrics_file, "w") as f:
            # Convert NumPy scalars to Python types: yaml.dump would otherwise
            # emit opaque ``!!python/object`` tags that safe_load cannot read.
            # The confusion matrix is saved separately as a plot.
            metrics_to_save = {
                k: _to_builtin(v) for k, v in metrics.items() if k != "confusion_matrix"
            }
            yaml.dump(metrics_to_save, f, default_flow_style=False)
        mlflow.log_artifact(str(metrics_file))
        logger.info(f"\n✓ Metrics saved to {metrics_file}")

        # Classification report
        report = classification_report(
            y_true,
            y_pred,
            target_names=["Healthy", "Dysarthric"],
            digits=4,
        )
        report_file = output_dir / "classification_report.txt"
        with open(report_file, "w") as f:
            f.write(report)
        mlflow.log_artifact(str(report_file))
        logger.info(f"✓ Classification report saved to {report_file}")

        # Confusion matrix
        cm_path = output_dir / "confusion_matrix.png"
        plot_confusion_matrix(metrics["confusion_matrix"], cm_path)
        mlflow.log_artifact(str(cm_path))
        logger.info(f"✓ Confusion matrix plot saved to {cm_path}")

        # ROC curve
        roc_path = output_dir / "roc_curve.png"
        plot_roc_curve(y_true, y_prob, metrics["auc"], roc_path)
        mlflow.log_artifact(str(roc_path))
        logger.info(f"✓ ROC curve saved to {roc_path}")

        # Precision-Recall curve
        pr_path = output_dir / "precision_recall_curve.png"
        plot_precision_recall_curve(y_true, y_prob, metrics["average_precision"], pr_path)
        mlflow.log_artifact(str(pr_path))
        logger.info(f"✓ Precision-Recall curve saved to {pr_path}")

        # Probability distribution
        prob_dist_path = output_dir / "probability_distribution.png"
        plot_probability_distribution(y_true, y_prob, prob_dist_path)
        mlflow.log_artifact(str(prob_dist_path))
        logger.info(f"✓ Probability distribution saved to {prob_dist_path}")

        # Error analysis
        errors_file = output_dir / "misclassified_samples.csv"
        errors_df = perform_error_analysis(y_true, y_pred, y_prob, file_paths, errors_file)
        mlflow.log_artifact(str(errors_file))
        logger.info(f"✓ Error analysis saved to {errors_file}")
        logger.info(f" Total errors: {len(errors_df)}")
        logger.info(f" False Positives: {len(errors_df[errors_df['error_type'] == 'False Positive'])}")
        logger.info(f" False Negatives: {len(errors_df[errors_df['error_type'] == 'False Negative'])}")

        logger.info("\n✓ Evaluation complete!")
441
+
442
+
443
+ if __name__ == "__main__":
444
+ main()
training/train_cnn_bilstm.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for CNN-BiLSTM-Transformer model (spectrogram branch).
3
+
4
+ This model processes log-mel spectrograms and CWT scalograms through:
5
+ 1. CNN feature extraction (ResNet-style blocks)
6
+ 2. BiLSTM temporal modeling
7
+ 3. Transformer encoder with self-attention
8
+ 4. Classification head
9
+
10
+ Usage:
11
+ python training/train_cnn_bilstm.py
12
+ """
13
+
14
import logging
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import mlflow
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from torch.utils.data import DataLoader
from tqdm import tqdm

# training/augmentation.py defines the class as ``AudioAugmenter``; alias it so
# the ``AudioAugmentor`` name used in this module resolves.
from training.augmentation import AudioAugmenter as AudioAugmentor
from training.dataset import DysarthriaDataset
31
+
32
+ # ══════════════════════════════════════════════════════════════════════════════
33
+ # CNN-BiLSTM-Transformer Model Architecture
34
+ # ══════════════════════════════════════════════════════════════════════════════
35
+
36
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in a residual connection.

    Spatial size is preserved (``padding=1``); only the channel count may
    change, in which case the shortcut is projected with a 1x1 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut must match the main path's channel count before the add.
        channels_match = in_channels == out_channels
        self.skip = (
            nn.Identity()
            if channels_match
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = self.skip(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return torch.relu(out + shortcut)
56
+
57
+
58
class CNNBiLSTMTransformer(nn.Module):
    """
    Spectrogram-based dysarthria detection model.

    Architecture:
        - CNN: Extract spatial features from spectrogram
        - BiLSTM: Model temporal dependencies
        - Transformer: Self-attention for long-range patterns
        - Classifier: Binary classification head
    """

    def __init__(
        self,
        input_channels: int = 2,  # Log-mel + CWT
        cnn_channels: list = None,
        lstm_hidden: int = 256,
        transformer_heads: int = 8,
        transformer_layers: int = 4,
        dropout: float = 0.3,
    ):
        """
        Args:
            input_channels: Number of stacked input feature maps.
            cnn_channels: Output channels of each residual block, in order.
                Defaults to ``[64, 128, 256]``.
            lstm_hidden: Hidden size per LSTM direction; the BiLSTM output
                (and the transformer width) is ``2 * lstm_hidden``.
            transformer_heads: Attention heads per encoder layer.
            transformer_layers: Number of stacked encoder layers.
            dropout: Dropout used in the LSTM, transformer and classifier.
        """
        super().__init__()

        # Avoid a shared mutable default argument.
        if cnn_channels is None:
            cnn_channels = [64, 128, 256]

        # ─────────────────────────────────────────────────────────────────────
        # CNN Feature Extractor
        # ─────────────────────────────────────────────────────────────────────
        self.cnn_blocks = nn.ModuleList()
        in_ch = input_channels
        for out_ch in cnn_channels:
            self.cnn_blocks.append(ResidualBlock(in_ch, out_ch))
            in_ch = out_ch

        # Collapse the frequency axis (dim 2) to size 1 and keep the time axis
        # (dim 3) untouched. AdaptiveAvgPool2d takes (H_out, W_out); the
        # previous (None, 1) pooled time instead, leaving a 4-D tensor that
        # the LSTM rejects.
        self.pool = nn.AdaptiveAvgPool2d((1, None))

        # ─────────────────────────────────────────────────────────────────────
        # BiLSTM Temporal Modeling
        # ─────────────────────────────────────────────────────────────────────
        self.lstm = nn.LSTM(
            input_size=cnn_channels[-1],
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

        # ─────────────────────────────────────────────────────────────────────
        # Transformer Encoder
        # ─────────────────────────────────────────────────────────────────────
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=lstm_hidden * 2,  # Bidirectional
            nhead=transformer_heads,
            dim_feedforward=lstm_hidden * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        # ─────────────────────────────────────────────────────────────────────
        # Classification Head
        # ─────────────────────────────────────────────────────────────────────
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2),  # Binary: healthy vs dysarthric
        )

    def forward(self, spectrogram):
        """
        Args:
            spectrogram: (batch, 2, freq, time) - Log-mel + CWT

        Returns:
            logits: (batch, 2)
        """
        # CNN feature extraction
        x = spectrogram
        for block in self.cnn_blocks:
            x = block(x)

        # Pool frequency dimension: (batch, channels, freq, time) → (batch, channels, time)
        x = self.pool(x).squeeze(2)

        # Transpose for LSTM: (batch, time, channels)
        x = x.transpose(1, 2)

        # BiLSTM
        x, _ = self.lstm(x)

        # Transformer encoder
        x = self.transformer(x)

        # Global average pooling over time
        x = x.mean(dim=1)  # (batch, lstm_hidden*2)

        # Classification
        logits = self.classifier(x)

        return logits
163
+
164
+
165
+ # ══════════════════════════════════════════════════════════════════════════════
166
+ # Training Loop
167
+ # ══════════════════════════════════════════════════════════════════════════════
168
+
169
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch dict must provide ``"spectrogram"`` and ``"label"`` tensors.
    Gradients are clipped to a global norm of 1.0 before every step.

    Returns:
        Tuple ``(mean batch loss, accuracy, binary F1)`` over the epoch.
    """
    model.train()
    running_loss = 0
    predictions = []
    targets = []

    for batch in tqdm(dataloader, desc="Training"):
        inputs = batch["spectrogram"].to(device)
        labels = batch["label"].to(device)

        # Forward
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)

        # Backward, with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Bookkeeping for epoch-level metrics
        running_loss += loss.item()
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    epoch_accuracy = accuracy_score(targets, predictions)
    epoch_f1 = f1_score(targets, predictions, average="binary")
    return mean_loss, epoch_accuracy, epoch_f1
201
+
202
+
203
def validate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Returns:
        Tuple ``(mean batch loss, accuracy, binary F1, ROC-AUC)``. AUC is
        computed from the softmax probability of class index 1.
    """
    model.eval()
    running_loss = 0
    predictions, positive_probs, targets = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            inputs = batch["spectrogram"].to(device)
            labels = batch["label"].to(device)

            logits = model(inputs)
            running_loss += criterion(logits, labels).item()

            positive_probs.extend(logits.softmax(dim=1)[:, 1].cpu().numpy())
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    val_accuracy = accuracy_score(targets, predictions)
    val_f1 = f1_score(targets, predictions, average="binary")
    val_auc = roc_auc_score(targets, positive_probs)
    return mean_loss, val_accuracy, val_f1, val_auc
233
+
234
+
235
def main():
    """Train the CNN-BiLSTM-Transformer branch and keep the best checkpoint.

    Reads ``configs/model_config.yaml`` (section ``cnn_bilstm``: ``batch_size``
    and ``epochs``), trains on ``data/manifests/train.csv``, validates on
    ``data/manifests/val.csv``, logs per-epoch metrics to MLflow and saves
    the checkpoint with the highest validation AUC to
    ``models/cnn_bilstm_best.pt``.
    """
    # ──────────────────────────────────────────────────────────────────────────
    # Setup
    # ──────────────────────────────────────────────────────────────────────────
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load config
    config_path = Path("configs/model_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Resolve the training hyperparameters once instead of re-reading the
    # config at every use site.
    train_cfg = config.get("cnn_bilstm", {})
    batch_size = train_cfg.get("batch_size", 16)
    num_epochs = train_cfg.get("epochs", 30)
    learning_rate = 1e-4

    # MLflow setup
    mlflow.set_experiment("cnn_bilstm_transformer_training")

    # ──────────────────────────────────────────────────────────────────────────
    # Data Loading
    # ──────────────────────────────────────────────────────────────────────────
    train_manifest = Path("data/manifests/train.csv")
    val_manifest = Path("data/manifests/val.csv")

    # The augmenter takes a (min, max) noise range, not a scalar
    # ``noise_level``; a degenerate range keeps the fixed level of 0.005.
    augmentor = AudioAugmentor(
        time_stretch_range=(0.9, 1.1),
        pitch_shift_range=(-2, 2),
        noise_level_range=(0.005, 0.005),
    )

    train_dataset = DysarthriaDataset(train_manifest, augmentor=augmentor, mode="train")
    val_dataset = DysarthriaDataset(val_manifest, augmentor=None, mode="val")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Model Setup
    # ──────────────────────────────────────────────────────────────────────────
    model = CNNBiLSTMTransformer(
        input_channels=2,
        cnn_channels=[64, 128, 256],
        lstm_hidden=256,
        transformer_heads=8,
        transformer_layers=4,
        dropout=0.3,
    ).to(device)

    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer and scheduler (halve the LR after 3 epochs without val-loss improvement)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases --
    # confirm against the pinned torch version before upgrading.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3, verbose=True
    )

    # Loss function. NOTE: currently unweighted; pass ``weight=`` here if the
    # classes turn out to be imbalanced.
    criterion = nn.CrossEntropyLoss()

    # ──────────────────────────────────────────────────────────────────────────
    # Training Loop
    # ──────────────────────────────────────────────────────────────────────────
    best_val_auc = 0
    best_model_path = Path("models/cnn_bilstm_best.pt")
    best_model_path.parent.mkdir(parents=True, exist_ok=True)

    with mlflow.start_run():
        # Log hyperparameters
        mlflow.log_params({
            "model": "cnn_bilstm_transformer",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "optimizer": "AdamW",
        })

        for epoch in range(1, num_epochs + 1):
            logger.info(f"\nEpoch {epoch}/{num_epochs}")

            # Train
            train_loss, train_acc, train_f1 = train_epoch(
                model, train_loader, optimizer, criterion, device
            )

            # Validate
            val_loss, val_acc, val_f1, val_auc = validate(
                model, val_loader, criterion, device
            )

            # Learning rate scheduling
            scheduler.step(val_loss)

            # Logging
            logger.info(
                f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}"
            )
            logger.info(
                f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, AUC: {val_auc:.4f}"
            )

            mlflow.log_metrics({
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "train_f1": train_f1,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "val_f1": val_f1,
                "val_auc": val_auc,
                "learning_rate": optimizer.param_groups[0]["lr"],
            }, step=epoch)

            # Save best model (selected on validation AUC)
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_auc": val_auc,
                }, best_model_path)
                logger.info(f"✓ New best model saved (AUC: {val_auc:.4f})")
                mlflow.log_artifact(str(best_model_path))

        logger.info(f"\n✓ Training complete! Best validation AUC: {best_val_auc:.4f}")
        mlflow.log_metric("best_val_auc", best_val_auc)
376
+
377
+
378
+ if __name__ == "__main__":
379
+ main()
training/train_ensemble_weights.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Optimize ensemble weights between HuBERT-SALR and CNN-BiLSTM models.
3
+
4
+ This script performs grid search to find the optimal alpha (mixing weight):
5
+ ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits
6
+
7
+ Usage:
8
+ python training/train_ensemble_weights.py
9
+ """
10
+
11
+ import sys
12
+ from pathlib import Path
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader
18
+ import mlflow
19
+ import numpy as np
20
+ from tqdm import tqdm
21
+ import yaml
22
+ import pandas as pd
23
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
24
+ import logging
25
+ import matplotlib.pyplot as plt
26
+ import seaborn as sns
27
+
28
+ from training.dataset import DysarthriaDataset
29
+
30
+
31
+ # ══════════════════════════════════════════════════════════════════════════════
32
+ # Model Loading Utilities
33
+ # ══════════════════════════════════════════════════════════════════════════════
34
+
35
def load_hubert_salr(checkpoint_path: Path, device):
    """Restore the HuBERT-SALR network from a checkpoint, ready for inference.

    Args:
        checkpoint_path: File whose ``"model_state_dict"`` entry holds the weights.
        device: Device the checkpoint is mapped to and the model moved onto.

    Returns:
        The model on ``device`` in eval mode.
    """
    # Imported inside the function, so the training script is only loaded
    # when this loader is actually called.
    from training.train_hubert_salr import HuBERTSALRModel

    network = HuBERTSALRModel()
    state = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(state["model_state_dict"])
    return network.to(device).eval()
45
+
46
+
47
def load_cnn_bilstm(checkpoint_path: Path, device):
    """Restore the CNN-BiLSTM-Transformer network from a checkpoint.

    Args:
        checkpoint_path: File whose ``"model_state_dict"`` entry holds the weights.
        device: Device the checkpoint is mapped to and the model moved onto.

    Returns:
        The model on ``device`` in eval mode.
    """
    # Imported inside the function, so the training script is only loaded
    # when this loader is actually called.
    from training.train_cnn_bilstm import CNNBiLSTMTransformer

    network = CNNBiLSTMTransformer()
    state = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(state["model_state_dict"])
    return network.to(device).eval()
57
+
58
+
59
+ # ══════════════════════════════════════════════════════════════════════════════
60
+ # Ensemble Evaluation
61
+ # ══════════════════════════════════════════════════════════════════════════════
62
+
63
def evaluate_ensemble(
    hubert_model,
    cnn_model,
    dataloader,
    alpha: float,
    device,
):
    """
    Evaluate ensemble with given alpha weight.

    The two models' logits are mixed as
    ``alpha * hubert_logits + (1 - alpha) * cnn_logits`` before softmax/argmax.

    Args:
        hubert_model: HuBERT-SALR model (fed ``batch["waveform"]``)
        cnn_model: CNN-BiLSTM model (fed ``batch["spectrogram"]``)
        dataloader: Validation data
        alpha: Mixing weight (0 to 1)
        device: torch device

    Returns:
        Dict with keys ``alpha``, ``accuracy``, ``f1``, ``auc``,
        ``sensitivity``, ``specificity`` and ``confusion_matrix`` (2x2).
    """
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Alpha={alpha:.2f}", leave=False):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)

            # Get predictions from both models
            hubert_logits = hubert_model(waveform)
            cnn_logits = cnn_model(spectrogram)

            # Ensemble
            ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits

            # Convert to predictions
            probs = torch.softmax(ensemble_logits, dim=1)[:, 1].cpu().numpy()
            preds = torch.argmax(ensemble_logits, dim=1).cpu().numpy()

            all_preds.extend(preds)
            all_probs.extend(probs)
            # Labels are only used for metrics, so they never need to visit the device.
            all_labels.extend(batch["label"].cpu().numpy())

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="binary")
    auc = roc_auc_score(all_labels, all_probs)
    # Pin both classes so the matrix is always 2x2; otherwise a run where only
    # one class appears yields a 1x1 matrix and the unpacking below fails.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    # Compute sensitivity and specificity
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        "alpha": alpha,
        "accuracy": accuracy,
        "f1": f1,
        "auc": auc,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "confusion_matrix": cm,
    }
128
+
129
+
130
+ # ══════════════════════════════════════════════════════════════════════════════
131
+ # Grid Search
132
+ # ══════════════════════════════════════════════════════════════════════════════
133
+
134
def grid_search_alpha(
    hubert_model,
    cnn_model,
    dataloader,
    device,
    alpha_range=(0.0, 1.0),
    num_points=21,
):
    """
    Sweep evenly spaced alpha values and score the ensemble at each one.

    Args:
        hubert_model: HuBERT-SALR model
        cnn_model: CNN-BiLSTM model
        dataloader: Validation data
        device: torch device
        alpha_range: (min, max) alpha values, both inclusive
        num_points: Number of alpha values to test

    Returns:
        DataFrame with one row of ``evaluate_ensemble`` metrics per alpha
    """
    # NOTE(review): every alpha triggers a full inference pass over the
    # dataloader; caching the per-model logits once would avoid the repeats.
    low, high = alpha_range[0], alpha_range[1]
    rows = [
        evaluate_ensemble(hubert_model, cnn_model, dataloader, candidate, device)
        for candidate in np.linspace(low, high, num_points)
    ]
    return pd.DataFrame(rows)
164
+
165
+
166
+ # ══════════════════════════════════════════════════════════════════════════════
167
+ # Visualization
168
+ # ══════════════════════════════════════════════════════════════════════════════
169
+
170
def plot_alpha_search(results_df, output_path: Path):
    """Plot accuracy, F1, AUC and sensitivity against alpha in a 2x2 grid.

    The alpha with the highest value of each metric is highlighted with a
    dashed vertical line, a red marker and a text annotation.

    Args:
        results_df: Grid-search results with an ``alpha`` column plus one
            column per plotted metric.
        output_path: Destination image file.
    """
    # (DataFrame column, display name) for each panel, in reading order.
    panels = (
        ("accuracy", "Accuracy"),
        ("f1", "F1 Score"),
        ("auc", "AUC-ROC"),
        ("sensitivity", "Sensitivity"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, (column, display_name) in zip(axes.flat, panels):
        ax.plot(results_df["alpha"], results_df[column], marker="o", linewidth=2)
        ax.set_xlabel("Alpha (HuBERT weight)", fontsize=12)
        ax.set_ylabel(display_name, fontsize=12)
        ax.set_title(f"{display_name} vs Alpha", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Mark best alpha
        peak_row = results_df[column].idxmax()
        peak_alpha = results_df.loc[peak_row, "alpha"]
        peak_value = results_df.loc[peak_row, column]

        ax.axvline(peak_alpha, color="red", linestyle="--", alpha=0.5)
        ax.scatter([peak_alpha], [peak_value], color="red", s=100, zorder=5)
        ax.text(
            peak_alpha,
            peak_value,
            f"α={peak_alpha:.2f}\n{peak_value:.4f}",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
202
+
203
+
204
def plot_confusion_matrix(cm, alpha, output_path: Path):
    """Save an annotated heatmap of the confusion matrix for one alpha.

    Args:
        cm: Matrix of integer counts (cells are formatted with ``"d"``).
            Two tick labels are supplied per axis, so a 2x2 matrix ordered
            Healthy, Dysarthric is expected.
        alpha: Ensemble weight shown in the plot title.
        output_path: Destination image file.
    """
    class_names = ["Healthy", "Dysarthric"]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )

    ax.set_title(f"Confusion Matrix (α={alpha:.2f})", fontsize=14)
    ax.set_ylabel("True Label", fontsize=12)
    ax.set_xlabel("Predicted Label", fontsize=12)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
221
+
222
+
223
+ # ══════════════════════════════════════════════════════════════════════════════
224
+ # Main
225
+ # ══════════════════════════════════════════════════════════════════════════════
226
+
227
def _resolve_checkpoint(filename: str) -> Path:
    """Return the first existing location of a checkpoint file.

    ``models/<filename>`` is tried first (the historical location used by this
    script), then ``models/checkpoints/<filename>``, which is where
    ``training/train_hubert_salr.py`` writes its best checkpoint. If neither
    exists, the historical path is returned so the caller's error message
    still points at the expected default.
    """
    default_path = Path("models") / filename
    for candidate in (default_path, Path("models/checkpoints") / filename):
        if candidate.exists():
            return candidate
    return default_path


def main():
    """Grid-search the ensemble weight ``alpha`` on the validation set.

    Loads the HuBERT-SALR and CNN-BiLSTM checkpoints, evaluates the ensemble
    over an evenly spaced grid of alpha values via ``grid_search_alpha``,
    selects the alpha with the highest AUC, and writes the results table,
    plots and a YAML config to ``reports/ensemble_optimization``. Everything
    is also logged to the ``ensemble_weight_optimization`` MLflow experiment.

    Returns early (after logging an error) when a checkpoint is missing.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Search-grid settings, defined once so the values logged to MLflow can
    # never drift from the values actually used by the search.
    alpha_min, alpha_max = 0.0, 1.0
    num_alpha_points = 21

    # ── Load models ──────────────────────────────────────────────────────────
    hubert_checkpoint = _resolve_checkpoint("hubert_salr_best.pt")
    cnn_checkpoint = _resolve_checkpoint("cnn_bilstm_best.pt")

    if not hubert_checkpoint.exists():
        logger.error(f"HuBERT checkpoint not found: {hubert_checkpoint}")
        logger.error("Please train HuBERT-SALR first: python training/train_hubert_salr.py")
        return

    if not cnn_checkpoint.exists():
        logger.error(f"CNN-BiLSTM checkpoint not found: {cnn_checkpoint}")
        logger.error("Please train CNN-BiLSTM first: python training/train_cnn_bilstm.py")
        return

    logger.info("Loading HuBERT-SALR model...")
    hubert_model = load_hubert_salr(hubert_checkpoint, device)

    logger.info("Loading CNN-BiLSTM model...")
    cnn_model = load_cnn_bilstm(cnn_checkpoint, device)

    # ── Load validation data ─────────────────────────────────────────────────
    val_manifest = Path("data/manifests/val.csv")
    val_dataset = DysarthriaDataset(val_manifest, augmentor=None, mode="val")
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    logger.info(f"Validation samples: {len(val_dataset)}")

    # ── Grid search ──────────────────────────────────────────────────────────
    mlflow.set_experiment("ensemble_weight_optimization")

    with mlflow.start_run():
        logger.info("\nStarting grid search over alpha values...")

        results_df = grid_search_alpha(
            hubert_model,
            cnn_model,
            val_loader,
            device,
            alpha_range=(alpha_min, alpha_max),
            num_points=num_alpha_points,
        )

        # Best alpha per metric (first occurrence wins on ties).
        best_idx_auc = results_df["auc"].idxmax()
        best_alpha_auc = results_df.loc[best_idx_auc, "alpha"]
        best_alpha_f1 = results_df.loc[results_df["f1"].idxmax(), "alpha"]
        best_alpha_acc = results_df.loc[results_df["accuracy"].idxmax(), "alpha"]

        logger.info("\n" + "=" * 80)
        logger.info("GRID SEARCH RESULTS")
        logger.info("=" * 80)
        logger.info(f"Best alpha (AUC): {best_alpha_auc:.2f}")
        logger.info(f"Best alpha (F1): {best_alpha_f1:.2f}")
        logger.info(f"Best alpha (Accuracy): {best_alpha_acc:.2f}")
        logger.info("=" * 80)

        # AUC is the primary selection metric.
        best_alpha = best_alpha_auc
        best_row = results_df.loc[best_idx_auc]

        logger.info(f"\nOptimal alpha: {best_alpha:.2f}")
        logger.info(f" Accuracy: {best_row['accuracy']:.4f}")
        logger.info(f" F1 Score: {best_row['f1']:.4f}")
        logger.info(f" AUC: {best_row['auc']:.4f}")
        logger.info(f" Sensitivity: {best_row['sensitivity']:.4f}")
        logger.info(f" Specificity: {best_row['specificity']:.4f}")

        # ── Log to MLflow ────────────────────────────────────────────────────
        mlflow.log_params({
            "num_alpha_points": num_alpha_points,
            "alpha_range_min": alpha_min,
            "alpha_range_max": alpha_max,
        })

        mlflow.log_metrics({
            "best_alpha": best_alpha,
            "best_accuracy": best_row["accuracy"],
            "best_f1": best_row["f1"],
            "best_auc": best_row["auc"],
            "best_sensitivity": best_row["sensitivity"],
            "best_specificity": best_row["specificity"],
        })

        # ── Save artifacts ───────────────────────────────────────────────────
        output_dir = Path("reports/ensemble_optimization")
        output_dir.mkdir(parents=True, exist_ok=True)

        results_csv = output_dir / "alpha_search_results.csv"
        results_df.to_csv(results_csv, index=False)
        mlflow.log_artifact(str(results_csv))
        logger.info(f"\n✓ Results saved to {results_csv}")

        plot_path = output_dir / "alpha_search_plot.png"
        plot_alpha_search(results_df, plot_path)
        mlflow.log_artifact(str(plot_path))
        logger.info(f"✓ Plots saved to {plot_path}")

        cm_path = output_dir / "confusion_matrix_best_alpha.png"
        plot_confusion_matrix(best_row["confusion_matrix"], best_alpha, cm_path)
        mlflow.log_artifact(str(cm_path))
        logger.info(f"✓ Confusion matrix saved to {cm_path}")

        # Plain floats keep the YAML free of numpy-specific tags.
        optimal_config = {
            "ensemble": {
                "alpha": float(best_alpha),
                "hubert_weight": float(best_alpha),
                "cnn_bilstm_weight": float(1 - best_alpha),
                "validation_metrics": {
                    "accuracy": float(best_row["accuracy"]),
                    "f1": float(best_row["f1"]),
                    "auc": float(best_row["auc"]),
                    "sensitivity": float(best_row["sensitivity"]),
                    "specificity": float(best_row["specificity"]),
                },
            }
        }

        config_path = output_dir / "optimal_ensemble_config.yaml"
        with open(config_path, "w") as f:
            yaml.dump(optimal_config, f, default_flow_style=False)
        mlflow.log_artifact(str(config_path))
        logger.info(f"✓ Optimal config saved to {config_path}")

        logger.info("\n✓ Ensemble weight optimization complete!")
        logger.info(f" Update configs/model_config.yaml with alpha={best_alpha:.2f}")


if __name__ == "__main__":
    main()
training/train_hubert_fast.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fast fine-tuning script for HuBERT-SALR model.
4
+
5
+ Optimizations:
6
+ - Reduced dataset size (500-1000 samples)
7
+ - Fewer epochs (5 instead of 20)
8
+ - Simplified model architecture
9
+ - Uses MPS/GPU acceleration
10
+ - Faster feature extraction
11
+
12
+ Usage:
13
+ python training/train_hubert_fast.py
14
+ """
15
+
16
+ import sys
17
+ from pathlib import Path
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.optim as optim
23
+ from torch.utils.data import DataLoader, Subset
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+ import logging
27
+ import pandas as pd
28
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
29
+ from transformers import HubertModel
30
+
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # ══════════════════════════════════════════════════════════════════════════════
36
+ # Simplified HuBERT Model
37
+ # ══════════════════════════════════════════════════════════════════════════════
38
+
39
class SimplifiedHuBERTClassifier(nn.Module):
    """HuBERT-base encoder with a small MLP head for binary classification.

    The encoder output is mean-pooled over time and passed to a two-layer
    classifier producing two logits (healthy vs dysarthric).

    Args:
        freeze_base: When True (default) the HuBERT encoder is frozen: its
            parameters do not require gradients, it is kept in eval mode and
            its forward pass runs without building an autograd graph. When
            False the encoder is fine-tuned together with the head.
    """

    def __init__(self, freeze_base=True):
        super().__init__()
        self.freeze_base = freeze_base

        # Load pre-trained HuBERT (base variant, chosen for speed)
        logger.info("Loading HuBERT-base model...")
        self.hubert = HubertModel.from_pretrained("facebook/hubert-base-ls960")

        # Freeze base model for faster training
        if freeze_base:
            for param in self.hubert.parameters():
                param.requires_grad = False
            self.hubert.eval()
            logger.info("✓ HuBERT base frozen (only training classifier)")

        # Simple classifier head
        hidden_size = self.hubert.config.hidden_size  # 768 for base
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2),  # Binary: healthy vs dysarthric
        )

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen encoder in eval mode.

        A frozen encoder is only a feature extractor; leaving it in train
        mode would keep its dropout (and any training-time masking) active
        and make the extracted features noisy.
        """
        super().train(mode)
        if self.freeze_base:
            self.hubert.eval()
        return self

    def forward(self, input_values):
        """Return class logits of shape (batch, 2) for raw waveforms.

        Args:
            input_values: Waveform batch fed to HuBERT — presumably
                (batch, samples) at 16 kHz; confirm against the dataset.
        """
        # Build an encoder graph only when the encoder is trainable AND the
        # caller has not disabled autograd (e.g. validation under no_grad).
        track_grad = torch.is_grad_enabled() and not self.freeze_base
        with torch.set_grad_enabled(track_grad):
            outputs = self.hubert(input_values)

        # Pool: mean across time dimension
        hidden_states = outputs.last_hidden_state  # (batch, time, hidden)
        pooled = hidden_states.mean(dim=1)  # (batch, hidden)

        # Classify
        return self.classifier(pooled)
76
+
77
+
78
+ # ══════════════════════════════════════════════════════════════════════════════
79
+ # Fast Dataset (No Heavy Feature Extraction)
80
+ # ══════════════════════════════════════════════════════════════════════════════
81
+
82
class FastDysarthriaDataset(torch.utils.data.Dataset):
    """Waveform dataset backed by a CSV manifest, for fast training.

    The manifest must provide ``file_path``, ``label`` and ``duration``
    columns. Rows whose duration lies outside ``[min_duration, max_duration]``
    (inclusive) are dropped. Each item is a dict with a fixed-length float
    ``waveform`` tensor of ``max_duration * sample_rate`` samples and an
    integer ``label``.

    Args:
        manifest_path: Path to the CSV manifest.
        max_duration: Longest clip kept, in the units of the ``duration``
            column (seconds, per the log message); also sets the fixed
            output length.
        sample_rate: Rate audio is resampled to when loaded.
        min_duration: Shortest clip kept (same units as ``max_duration``).
    """

    def __init__(self, manifest_path, max_duration=10.0, sample_rate=16000, min_duration=5.0):
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.sample_rate = sample_rate
        self.max_length = int(max_duration * sample_rate)

        # Keep only clips inside the allowed duration window (inclusive).
        manifest = pd.read_csv(manifest_path)
        in_window = manifest['duration'].between(min_duration, max_duration)
        self.manifest = manifest[in_window].reset_index(drop=True)

        logger.info(
            f"Dataset: {len(self.manifest)} samples "
            f"(filtered for {min_duration:g}-{max_duration:g}s duration)"
        )

    @staticmethod
    def _fit_length(waveform, target_length):
        """Truncate or right-pad a 1-D array with zeros to ``target_length``."""
        if len(waveform) > target_length:
            return waveform[:target_length]
        return np.pad(waveform, (0, target_length - len(waveform)))

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        row = self.manifest.iloc[idx]

        # Imported lazily so the module can be imported without librosa.
        import librosa
        waveform, _ = librosa.load(row['file_path'], sr=self.sample_rate)

        # Pad or truncate to fixed length so samples can be batched directly.
        waveform = self._fit_length(waveform, self.max_length)

        return {
            'waveform': torch.FloatTensor(waveform),
            'label': int(row['label']),
        }
119
+
120
+
121
+ # ══════════════════════════════════════════════════════════════════════════════
122
+ # Training Functions
123
+ # ══════════════════════════════════════════════════════════════════════════════
124
+
125
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is a dict with ``"waveform"`` and ``"label"`` tensors.

    Returns:
        Tuple ``(mean batch loss, accuracy, binary F1)`` computed over all
        batches of the epoch.
    """
    model.train()

    running_loss = 0
    predicted, targets = [], []

    for batch in tqdm(dataloader, desc="Training"):
        inputs = batch["waveform"].to(device)
        y_true = batch["label"].to(device)

        optimizer.zero_grad()
        scores = model(inputs)
        batch_loss = criterion(scores, y_true)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level metrics.
        running_loss += batch_loss.item()
        predicted.extend(scores.argmax(dim=1).cpu().numpy())
        targets.extend(y_true.cpu().numpy())

    return (
        running_loss / len(dataloader),
        accuracy_score(targets, predicted),
        f1_score(targets, predicted, average="binary"),
    )
153
+
154
+
155
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Each batch is a dict with ``"waveform"`` and ``"label"`` tensors.

    Returns:
        Tuple ``(mean batch loss, accuracy, binary F1, ROC AUC)``. AUC uses
        the softmax probability of class index 1.
    """
    model.eval()

    running_loss = 0
    predicted, positive_scores, targets = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            inputs = batch["waveform"].to(device)
            y_true = batch["label"].to(device)

            scores = model(inputs)
            running_loss += criterion(scores, y_true).item()

            # Probability of the positive class (column 1) feeds the AUC.
            class1_prob = scores.softmax(dim=1)[:, 1].cpu().numpy()
            hard_pred = scores.argmax(dim=1).cpu().numpy()

            predicted.extend(hard_pred)
            positive_scores.extend(class1_prob)
            targets.extend(y_true.cpu().numpy())

    return (
        running_loss / len(dataloader),
        accuracy_score(targets, predicted),
        f1_score(targets, predicted, average="binary"),
        roc_auc_score(targets, positive_scores),
    )
185
+
186
+
187
+ # ══════════════════════════════════════════════════════════════════════════════
188
+ # Main Training
189
+ # ══════════════════════════════════════════════════════════════════════════════
190
+
191
def _limit_samples(dataset, limit, split_name):
    """Return ``dataset`` unchanged, or a random ``limit``-sized subset if larger."""
    if len(dataset) <= limit:
        return dataset
    chosen = np.random.choice(len(dataset), limit, replace=False)
    logger.info(f"✂️ Using subset: {limit} {split_name} samples")
    return Subset(dataset, chosen)


def main():
    """Train the simplified HuBERT classifier on a small data subset.

    Picks CUDA, then MPS, then CPU; loads the train/val manifests; caps both
    splits to a fixed number of random samples; trains only the classifier
    head for a few epochs; and saves the checkpoint with the best validation
    AUC to ``models/hubert_fast_best.pt``.
    """
    # Device selection: CUDA first, then Apple MPS, otherwise CPU.
    device_name = (
        "cuda"
        if torch.cuda.is_available()
        else ("mps" if torch.backends.mps.is_available() else "cpu")
    )
    device = torch.device(device_name)
    logger.info(f"🚀 Using device: {device}")

    # Datasets
    train_set = FastDysarthriaDataset(Path("data/manifests/train.csv"), max_duration=10.0)
    val_set = FastDysarthriaDataset(Path("data/manifests/val.csv"), max_duration=10.0)

    # Cap split sizes for speed (comments in the original: reduced from 3000 / 647).
    train_cap = 500
    val_cap = 100
    train_set = _limit_samples(train_set, train_cap, "training")
    val_set = _limit_samples(val_set, val_cap, "validation")

    # Data loaders: small batches, no worker processes (avoids multiprocessing issues).
    shared_loader_args = {"batch_size": 4, "num_workers": 0}
    train_loader = DataLoader(train_set, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_set, shuffle=False, **shared_loader_args)

    # Model
    model = SimplifiedHuBERTClassifier(freeze_base=True).to(device)
    logger.info(f"✓ Model loaded on {device}")

    # Only the head is optimised; a higher LR is fine with a frozen base.
    optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    total_epochs = 5  # Reduced from 20
    best_val_auc = 0
    best_model_path = Path("models/hubert_fast_best.pt")
    best_model_path.parent.mkdir(parents=True, exist_ok=True)

    banner = "=" * 80
    logger.info(f"\n{banner}")
    logger.info(f" FAST TRAINING - {total_epochs} epochs")
    logger.info(f"{banner}\n")

    for epoch in range(1, total_epochs + 1):
        logger.info(f"\nEpoch {epoch}/{total_epochs}")
        logger.info("-" * 40)

        train_loss, train_acc, train_f1 = train_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_acc, val_f1, val_auc = validate(
            model, val_loader, criterion, device
        )

        logger.info(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.4f}, F1={train_f1:.4f}")
        logger.info(f"Val: Loss={val_loss:.4f}, Acc={val_acc:.4f}, F1={val_f1:.4f}, AUC={val_auc:.4f}")

        # Nothing to save unless validation AUC improved.
        if not val_auc > best_val_auc:
            continue

        best_val_auc = val_auc
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_auc': val_auc,
        }
        torch.save(checkpoint, best_model_path)
        logger.info(f"✓ New best model saved (AUC: {val_auc:.4f})")

    logger.info(f"\n{banner}")
    logger.info(" ✓ TRAINING COMPLETE!")
    logger.info(f"{banner}")
    logger.info(f"Best validation AUC: {best_val_auc:.4f}")
    logger.info(f"Model saved to: {best_model_path}")
    logger.info("\nNext steps:")
    logger.info(" 1. Test the model on test set")
    logger.info(" 2. Update model_registry.py to use this checkpoint")
    logger.info(" 3. Run inference on new audio files")


if __name__ == "__main__":
    main()
training/train_hubert_salr.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Train HuBERT-SALR model for dysarthria detection."""
3
+
4
+ import logging
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ from pathlib import Path
9
+ import mlflow
10
+ import yaml
11
+
12
+ from training.dataset import DysarthriaDataset, collate_fn
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class HuBERTSALRModel(nn.Module):
    """HuBERT with SALR head for dysarthria detection.

    The transformer layer outputs are combined with learnable per-layer
    weights, mean-pooled over time, and fed to two heads: a binary classifier
    and an L2-normalised 128-d embedder (intended for a triplet loss).

    Args:
        hubert_checkpoint: Hugging Face model id or path of the pretrained
            HuBERT. Layer count and hidden width are read from its config,
            so checkpoints other than the large (24 x 1024) default work.
    """

    def __init__(self, hubert_checkpoint="facebook/hubert-large-ll60k"):
        super().__init__()

        from transformers import HubertModel

        # Load pretrained HuBERT
        self.hubert = HubertModel.from_pretrained(hubert_checkpoint)

        # Freeze the convolutional feature extractor; the transformer
        # layers stay trainable.
        for param in self.hubert.feature_extractor.parameters():
            param.requires_grad = False

        # Derive sizes from the checkpoint instead of assuming hubert-large.
        num_layers = self.hubert.config.num_hidden_layers  # 24 for large
        hidden_size = self.hubert.config.hidden_size  # 1024 for large

        # Layer-weighted pooling (one learnable weight per transformer layer)
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # SALR head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2),  # Binary classification
        )

        self.embedder = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),  # Embedding for triplet loss
        )

    def forward(self, waveform):
        """Return ``(logits, embeddings)`` for a batch of waveforms.

        ``logits`` has one row of 2 class scores per input; ``embeddings``
        is the L2-normalised output of the embedder head.
        """
        # HuBERT encoding
        outputs = self.hubert(waveform, output_hidden_states=True)

        # ``hidden_states`` holds the embedding output followed by one tensor
        # per transformer layer, so take the LAST ``num_layers`` entries: this
        # skips the embedding output and includes the final layer.
        # NOTE(review): earlier code weighted indices 0..23 (embedding output
        # included, last layer dropped); checkpoints trained that way load
        # fine but their learned weights referred to shifted layers.
        num_layers = self.layer_weights.shape[0]
        layer_outputs = torch.stack(
            outputs.hidden_states[-num_layers:], dim=0
        )  # (layers, batch, seq_len, hidden)

        # Layer-weighted pooling
        weights = self.layer_weights.view(-1, 1, 1, 1)
        weighted_hidden = (weights * layer_outputs).sum(dim=0)  # (batch, seq_len, hidden)

        # Global average pooling over time
        pooled = weighted_hidden.mean(dim=1)  # (batch, hidden)

        # Classification logits
        logits = self.classifier(pooled)

        # Embeddings for triplet loss
        embeddings = nn.functional.normalize(self.embedder(pooled), p=2, dim=1)

        return logits, embeddings
74
+
75
+
76
def train_hubert_salr(
    train_manifest="data/manifests/train_manifest.csv",
    val_manifest="data/manifests/val_manifest.csv",
    batch_size=8,
    num_epochs=50,
    learning_rate=1e-4,
    device="cuda",
):
    """
    Train HuBERT-SALR model.

    Runs a cross-entropy training loop, logs parameters and per-epoch metrics
    to the ``dysarthria_hubert_salr`` MLflow experiment, and saves the model
    state dict with the lowest validation loss to
    ``models/checkpoints/hubert_salr_best.pt``.

    Args:
        train_manifest: Path to training manifest
        val_manifest: Path to validation manifest
        batch_size: Batch size
        num_epochs: Number of epochs
        learning_rate: Learning rate
        device: Requested device string (e.g. "cuda", "cpu", "mps"). A CUDA
            request falls back to CPU when CUDA is unavailable; any other
            device is used as requested.
    """
    # Resolve device: only a CUDA request is downgraded, so an explicit
    # "cpu" or "mps" request is honoured on machines without CUDA.
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU")
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # Initialize MLflow
    mlflow.set_experiment("dysarthria_hubert_salr")

    checkpoint_path = Path("models/checkpoints/hubert_salr_best.pt")

    with mlflow.start_run():
        # Log parameters
        mlflow.log_params({
            "model": "HuBERT-SALR",
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "learning_rate": learning_rate,
        })

        # Create datasets
        train_dataset = DysarthriaDataset(train_manifest, augment=True)
        val_dataset = DysarthriaDataset(val_manifest, augment=False)

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,  # Disabled for compatibility
            collate_fn=collate_fn,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,  # Disabled for compatibility
            collate_fn=collate_fn,
        )

        # Initialize model
        model = HuBERTSALRModel().to(device)

        # Optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        # Loss
        ce_loss_fn = nn.CrossEntropyLoss()

        # Training loop
        best_val_loss = float("inf")

        for epoch in range(num_epochs):
            # Training
            model.train()
            train_loss = 0.0

            for batch in train_loader:
                waveform = batch["waveform"].to(device)
                # assumes collate_fn yields labels shaped (batch, 1) — TODO confirm
                labels = batch["label"].squeeze(1).to(device)

                optimizer.zero_grad()

                # Forward pass
                logits, embeddings = model(waveform)

                # Classification loss
                ce_loss = ce_loss_fn(logits, labels)

                # TODO: triplet loss is a placeholder (always zero), so the
                # embedder head receives no gradient yet. A full version
                # would mine triplets from ``embeddings`` and apply
                # nn.TripletMarginLoss(margin=1.0).
                triplet_loss = torch.tensor(0.0).to(device)

                # Combined loss
                loss = ce_loss + 0.5 * triplet_loss

                # Backward pass
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            train_loss /= len(train_loader)

            # Validation
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for batch in val_loader:
                    waveform = batch["waveform"].to(device)
                    labels = batch["label"].squeeze(1).to(device)

                    logits, _ = model(waveform)
                    loss = ce_loss_fn(logits, labels)

                    val_loss += loss.item()

                    preds = logits.argmax(dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

            val_loss /= len(val_loader)
            val_acc = correct / total

            # Log metrics
            mlflow.log_metrics({
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
            }, step=epoch)

            logger.info(
                f"Epoch {epoch+1}/{num_epochs}: "
                f"train_loss={train_loss:.4f}, "
                f"val_loss={val_loss:.4f}, "
                f"val_acc={val_acc:.4f}"
            )

            # Save best model (lowest validation loss so far)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)
                mlflow.log_artifact(str(checkpoint_path))

        logger.info("Training complete!")


if __name__ == "__main__":
    train_hubert_salr()