Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- training/augmentation.py +60 -0
- training/calibrate.py +419 -0
- training/dataset.py +202 -0
- training/evaluate.py +444 -0
- training/train_cnn_bilstm.py +379 -0
- training/train_ensemble_weights.py +376 -0
- training/train_hubert_fast.py +297 -0
- training/train_hubert_salr.py +225 -0
training/augmentation.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio data augmentation for training."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import librosa
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AudioAugmenter:
    """Randomly perturb raw audio waveforms to diversify training data.

    Three augmentations are attempted in a fixed order -- time stretching,
    pitch shifting and additive Gaussian noise -- and each one is applied
    independently with probability ``apply_prob``.
    """

    def __init__(
        self,
        time_stretch_range=(0.8, 1.2),
        pitch_shift_range=(-2, 2),
        noise_level_range=(0.005, 0.015),
        apply_prob=0.5,
    ):
        """
        Initialize augmenter.

        Args:
            time_stretch_range: (min_rate, max_rate) for time stretching
            pitch_shift_range: (min_steps, max_steps) for pitch shifting
            noise_level_range: (min_level, max_level) for additive noise
            apply_prob: Probability of applying each augmentation
        """
        self.time_stretch_range = time_stretch_range
        self.pitch_shift_range = pitch_shift_range
        self.noise_level_range = noise_level_range
        self.apply_prob = apply_prob

    def augment(self, waveform: np.ndarray, sr: int) -> np.ndarray:
        """
        Apply random augmentations.

        Randomness comes from NumPy's global generator (``np.random``), so
        results are reproducible only if the caller seeds it.

        Args:
            waveform: Audio waveform
            sr: Sample rate (only used by pitch shifting)

        Returns:
            Augmented waveform. Time stretching changes the number of
            samples, so the output length may differ from the input. When no
            augmentation fires, the input object is returned unchanged.
        """
        # Time stretching
        if np.random.rand() < self.apply_prob:
            rate = np.random.uniform(*self.time_stretch_range)
            waveform = librosa.effects.time_stretch(waveform, rate=rate)

        # Pitch shifting
        if np.random.rand() < self.apply_prob:
            n_steps = np.random.uniform(*self.pitch_shift_range)
            waveform = librosa.effects.pitch_shift(waveform, sr=sr, n_steps=n_steps)

        # Additive noise
        if np.random.rand() < self.apply_prob:
            waveform = np.asarray(waveform)
            noise_level = np.random.uniform(*self.noise_level_range)
            # Draw noise with the waveform's full shape (not just its first
            # axis) so non-1-D input does not rely on broadcasting.
            noise = np.random.randn(*waveform.shape) * noise_level
            # randn yields float64; match floating-point input so float32
            # audio is not silently upcast.
            if np.issubdtype(waveform.dtype, np.floating):
                noise = noise.astype(waveform.dtype, copy=False)
            waveform = waveform + noise

        # SpecAugment (applied at spectrogram level, not here)

        return waveform
|
training/calibrate.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calibrate model probabilities using Platt scaling.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads the ensemble model
|
| 6 |
+
2. Collects predictions on a held-out calibration set
|
| 7 |
+
3. Fits Platt scaling parameters (a, b) via logistic regression
|
| 8 |
+
4. Evaluates calibration quality (ECE, reliability diagrams)
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python training/calibrate.py
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.utils.data import DataLoader
|
| 21 |
+
import mlflow
|
| 22 |
+
import numpy as np
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
import yaml
|
| 25 |
+
import logging
|
| 26 |
+
from sklearn.linear_model import LogisticRegression
|
| 27 |
+
from sklearn.metrics import brier_score_loss, log_loss
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
|
| 30 |
+
from training.dataset import DysarthriaDataset
|
| 31 |
+
from training.train_hubert_salr import HuBERTSALRModel
|
| 32 |
+
from training.train_cnn_bilstm import CNNBiLSTMTransformer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 36 |
+
# Calibration Metrics
|
| 37 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
|
| 39 |
+
def expected_calibration_error(y_true, y_prob, n_bins=10):
    """
    Compute Expected Calibration Error (ECE).

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1].
    For each non-empty bin the absolute gap between the fraction of positive
    labels and the mean predicted probability is weighted by the share of
    samples in that bin. Lower ECE indicates better calibration.

    Args:
        y_true: True labels (0 or 1); any array-like, flattened internally
        y_prob: Predicted probabilities (0 to 1); any array-like, flattened
            internally
        n_bins: Number of bins for binning predictions

    Returns:
        ECE value (0.0 for empty input)
    """
    # Coerce so plain lists and (n, 1) column vectors behave like 1-D arrays;
    # boolean-mask indexing below requires ndarray inputs of matching rank.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()

    n_samples = len(y_true)
    if n_samples == 0:
        return 0.0

    bin_edges = np.linspace(0, 1, n_bins + 1)
    # digitize against the left edges gives 1-based bin numbers; the clip
    # folds out-of-range probabilities (and exactly 1.0) into the end bins.
    bin_indices = np.digitize(y_prob, bin_edges[:-1]) - 1
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)

    ece = 0.0
    for i in range(n_bins):
        mask = bin_indices == i
        count = mask.sum()
        if count > 0:
            bin_acc = y_true[mask].mean()
            bin_conf = y_prob[mask].mean()
            ece += (count / n_samples) * np.abs(bin_acc - bin_conf)

    return ece
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def reliability_curve(y_true, y_prob, n_bins=10):
    """
    Compute reliability curve data for plotting.

    Uses the same equal-width binning over [0, 1] as
    ``expected_calibration_error``. Empty bins are omitted, so the returned
    arrays may be shorter than ``n_bins``.

    Args:
        y_true: True labels (0 or 1); any array-like, flattened internally
        y_prob: Predicted probabilities (0 to 1); any array-like, flattened
            internally
        n_bins: Number of equal-width bins

    Returns:
        bin_centers, bin_accuracies, bin_confidences, bin_counts
        (numpy arrays with one entry per non-empty bin)
    """
    # Coerce so lists and (n, 1) column vectors can be boolean-indexed.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()

    bin_edges = np.linspace(0, 1, n_bins + 1)
    # 1-based bin numbers from the left edges, shifted to 0-based and
    # clipped so out-of-range values fall into the first/last bin.
    bin_indices = np.digitize(y_prob, bin_edges[:-1]) - 1
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)

    bin_centers = []
    bin_accuracies = []
    bin_confidences = []
    bin_counts = []

    for i in range(n_bins):
        mask = bin_indices == i
        count = mask.sum()
        if count > 0:
            bin_centers.append((bin_edges[i] + bin_edges[i + 1]) / 2)
            bin_accuracies.append(y_true[mask].mean())
            bin_confidences.append(y_prob[mask].mean())
            bin_counts.append(count)

    return (
        np.array(bin_centers),
        np.array(bin_accuracies),
        np.array(bin_confidences),
        np.array(bin_counts),
    )
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 103 |
+
# Model Inference
|
| 104 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 105 |
+
|
| 106 |
+
def collect_predictions(hubert_model, cnn_model, dataloader, alpha, device):
    """
    Collect raw logits and probabilities from ensemble.

    Both models are switched to eval mode and run without gradient tracking.
    The ensemble logits are the convex combination
    ``alpha * hubert_logits + (1 - alpha) * cnn_logits``.

    Args:
        hubert_model: HuBERT-SALR model, called with the batch waveform
        cnn_model: CNN-BiLSTM model, called with the batch spectrogram
        dataloader: Data loader yielding dicts with "waveform",
            "spectrogram" and "label" entries
        alpha: Ensemble mixing weight
        device: torch device

    Returns:
        logits, probabilities, true labels (all 1-D numpy arrays with one
        entry per sample). ``logits`` holds the class-1 ensemble logit and
        ``probabilities`` the class-1 softmax probability.
    """
    all_logits = []
    all_probs = []
    all_labels = []

    hubert_model.eval()
    cnn_model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Collecting predictions"):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)
            labels = batch["label"]

            # Ensemble logits
            hubert_logits = hubert_model(waveform)
            cnn_logits = cnn_model(spectrogram)
            ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits

            # Probabilities (uncalibrated)
            probs = torch.softmax(ensemble_logits, dim=1)[:, 1]

            # NOTE(review): only the class-1 logit is kept, whereas the
            # softmax probability depends on (logit_1 - logit_0). Confirm
            # this matches how the Platt parameters are applied at inference.
            all_logits.extend(ensemble_logits[:, 1].cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            # The dataset emits labels of shape (1,), so a batch is (B, 1);
            # flatten so the returned label array is 1-D like the others.
            all_labels.extend(labels.reshape(-1).cpu().numpy())

    return (
        np.array(all_logits),
        np.array(all_probs),
        np.array(all_labels),
    )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 153 |
+
# Platt Scaling
|
| 154 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 155 |
+
|
| 156 |
+
def fit_platt_scaling(logits, labels):
    """
    Fit Platt scaling parameters.

    Platt scaling fits:
        calibrated_prob = sigmoid(a * logit + b)

    by training an unregularized one-feature logistic regression on the
    logits.

    Args:
        logits: Raw model logits, array-like with n_samples entries
        labels: True binary labels, array-like with n_samples entries
            (column vectors are flattened)

    Returns:
        a, b parameters (slope and intercept of the fitted regression)
    """
    # sklearn expects a 2-D feature matrix and a 1-D target vector.
    X = np.asarray(logits, dtype=float).reshape(-1, 1)
    y = np.asarray(labels).ravel()

    # Fit logistic regression (no regularization)
    lr = LogisticRegression(penalty=None, solver="lbfgs", max_iter=1000)
    lr.fit(X, y)

    a = lr.coef_[0][0]
    b = lr.intercept_[0]

    return a, b
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def apply_platt_scaling(logits, a, b):
    """
    Apply Platt scaling to logits.

    Computes ``sigmoid(a * logits + b)``.

    Args:
        logits: Raw model logits (scalar or array-like)
        a: Platt slope
        b: Platt intercept

    Returns:
        Calibrated probabilities with the same shape as ``logits``
    """
    z = a * np.asarray(logits, dtype=float) + b
    # sigmoid(z) == exp(-log(1 + exp(-z))); logaddexp evaluates the inner
    # log without overflowing when z is a large negative number, unlike the
    # naive 1 / (1 + exp(-z)).
    calibrated_probs = np.exp(-np.logaddexp(0.0, -z))
    return calibrated_probs
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 192 |
+
# Visualization
|
| 193 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 194 |
+
|
| 195 |
+
def plot_reliability_diagram(
    y_true,
    y_prob_uncal,
    y_prob_cal,
    output_path: Path,
):
    """Save a two-panel reliability diagram (uncalibrated vs calibrated).

    Each panel plots the fraction of positives against the mean predicted
    probability per bin, with marker area scaled by bin count, and reports
    ECE and Brier score in the title. The figure is written to
    ``output_path`` and then closed.
    """
    _, panel_axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        ("Uncalibrated", y_prob_uncal),
        ("Calibrated", y_prob_cal),
    )

    for axis, (heading, panel_probs) in zip(panel_axes, panels):
        _, frac_pos, mean_pred, sizes = reliability_curve(
            y_true, panel_probs, n_bins=10
        )

        # Diagonal reference first, then the binned curve on top of it.
        axis.plot([0, 1], [0, 1], "k--", label="Perfect calibration", linewidth=2)
        axis.scatter(
            mean_pred, frac_pos, s=sizes * 3, alpha=0.6, label="Model", zorder=5
        )
        axis.plot(mean_pred, frac_pos, "o-", linewidth=2, markersize=8)

        # Summary metrics shown in the panel title.
        panel_ece = expected_calibration_error(y_true, panel_probs)
        panel_brier = brier_score_loss(y_true, panel_probs)

        axis.set_xlabel("Mean Predicted Probability", fontsize=12)
        axis.set_ylabel("Fraction of Positives", fontsize=12)
        axis.set_title(
            f"{heading}\nECE: {panel_ece:.4f}, Brier: {panel_brier:.4f}",
            fontsize=14,
        )
        axis.legend(fontsize=10)
        axis.grid(True, alpha=0.3)
        axis.set_xlim([0, 1])
        axis.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def plot_histogram_comparison(
    y_true,
    y_prob_uncal,
    y_prob_cal,
    output_path: Path,
):
    """Save a 2x2 grid of predicted-probability histograms.

    Rows are uncalibrated (top) and calibrated (bottom) probabilities;
    columns are samples whose true label is 1 (left, red) and 0 (right,
    blue). The figure is written to ``output_path`` and then closed.
    """
    _, grid = plt.subplots(2, 2, figsize=(14, 10))

    # Split by true label
    is_positive = y_true == 1
    is_negative = y_true == 0

    rows = ((y_prob_uncal, "Uncalibrated"), (y_prob_cal, "Calibrated"))
    for row_idx, (row_probs, title) in enumerate(rows):
        # One entry per column: (sample mask, bar colour, panel title).
        columns = (
            (is_positive, "red", f"{title} - True Dysarthric"),
            (is_negative, "blue", f"{title} - True Healthy"),
        )
        for col_idx, (selector, colour, heading) in enumerate(columns):
            panel = grid[row_idx, col_idx]
            panel.hist(
                row_probs[selector],
                bins=20,
                alpha=0.7,
                color=colour,
                edgecolor="black",
            )
            panel.set_xlabel("Predicted Probability", fontsize=12)
            panel.set_ylabel("Count", fontsize=12)
            panel.set_title(heading, fontsize=14)
            panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 269 |
+
# Main
|
| 270 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 271 |
+
|
| 272 |
+
def main():
    """Fit and evaluate Platt scaling for the HuBERT + CNN-BiLSTM ensemble.

    Steps: load both checkpoints, read the ensemble weight ``alpha`` from
    ``configs/model_config.yaml``, collect ensemble predictions on the
    validation manifest, fit Platt parameters, compare ECE / Brier / log
    loss before and after calibration, log everything to MLflow, and write
    the parameters plus diagnostic plots to ``reports/calibration``.

    All paths are relative to the current working directory.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Imported here because only this entry point needs the padding collate
    # function that accompanies DysarthriaDataset.
    from training.dataset import collate_fn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Models
    # ──────────────────────────────────────────────────────────────────────────
    logger.info("Loading models...")

    hubert_checkpoint = Path("models/hubert_salr_best.pt")
    cnn_checkpoint = Path("models/cnn_bilstm_best.pt")

    hubert_model = HuBERTSALRModel()
    hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"])
    hubert_model.to(device)

    cnn_model = CNNBiLSTMTransformer()
    cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"])
    cnn_model.to(device)

    # Load optimal alpha
    with open("configs/model_config.yaml") as f:
        # safe_load returns None for an empty file; fall back to defaults.
        config = yaml.safe_load(f) or {}
    alpha = (config.get("ensemble") or {}).get("alpha", 0.6)
    logger.info(f"Using ensemble alpha: {alpha}")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Calibration Data (use validation set)
    # ──────────────────────────────────────────────────────────────────────────
    val_manifest = Path("data/manifests/val.csv")
    # DysarthriaDataset accepts (manifest_path, augment, cache_features);
    # calibration must see un-augmented audio.
    val_dataset = DysarthriaDataset(val_manifest, augment=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,  # pads variable-length waveforms/spectrograms
    )

    logger.info(f"Calibration samples: {len(val_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Collect Predictions
    # ──────────────────────────────────────────────────────────────────────────
    logger.info("Collecting predictions...")

    logits, probs_uncal, labels = collect_predictions(
        hubert_model, cnn_model, val_loader, alpha, device
    )
    # Dataset labels carry a trailing singleton axis; the metrics and the
    # boolean masks in the plots need a flat vector aligned with the probs.
    labels = np.asarray(labels).ravel()

    # ──────────────────────────────────────────────────────────────────────────
    # Fit Platt Scaling
    # ──────────────────────────────────────────────────────────────────────────
    mlflow.set_experiment("model_calibration")

    with mlflow.start_run():
        logger.info("\nFitting Platt scaling...")

        a, b = fit_platt_scaling(logits, labels)
        logger.info(f"Platt parameters: a={a:.6f}, b={b:.6f}")

        # Apply calibration
        probs_cal = apply_platt_scaling(logits, a, b)

        # ──────────────────────────────────────────────────────────────────────
        # Evaluate Calibration
        # ──────────────────────────────────────────────────────────────────────
        # NOTE(review): the "calibrated" metrics are computed on the same
        # samples the Platt parameters were fitted on, so they are
        # optimistic; consider a separate held-out split.
        ece_uncal = expected_calibration_error(labels, probs_uncal)
        ece_cal = expected_calibration_error(labels, probs_cal)

        brier_uncal = brier_score_loss(labels, probs_uncal)
        brier_cal = brier_score_loss(labels, probs_cal)

        logloss_uncal = log_loss(labels, probs_uncal)
        logloss_cal = log_loss(labels, probs_cal)

        logger.info("\n" + "=" * 80)
        logger.info("CALIBRATION RESULTS")
        logger.info("=" * 80)
        logger.info(f"Expected Calibration Error (ECE):")
        logger.info(f" Uncalibrated: {ece_uncal:.4f}")
        logger.info(f" Calibrated: {ece_cal:.4f} ({'↓' if ece_cal < ece_uncal else '↑'} {abs(ece_cal - ece_uncal):.4f})")
        logger.info(f"\nBrier Score:")
        logger.info(f" Uncalibrated: {brier_uncal:.4f}")
        logger.info(f" Calibrated: {brier_cal:.4f} ({'↓' if brier_cal < brier_uncal else '↑'} {abs(brier_cal - brier_uncal):.4f})")
        logger.info(f"\nLog Loss:")
        logger.info(f" Uncalibrated: {logloss_uncal:.4f}")
        logger.info(f" Calibrated: {logloss_cal:.4f} ({'↓' if logloss_cal < logloss_uncal else '↑'} {abs(logloss_cal - logloss_uncal):.4f})")
        logger.info("=" * 80)

        # Log to MLflow
        mlflow.log_params({
            "platt_a": a,
            "platt_b": b,
            "calibration_samples": len(labels),
        })

        mlflow.log_metrics({
            "ece_uncalibrated": ece_uncal,
            "ece_calibrated": ece_cal,
            "ece_improvement": ece_uncal - ece_cal,
            "brier_uncalibrated": brier_uncal,
            "brier_calibrated": brier_cal,
            "logloss_uncalibrated": logloss_uncal,
            "logloss_calibrated": logloss_cal,
        })

        # ──────────────────────────────────────────────────────────────────────
        # Save Results
        # ──────────────────────────────────────────────────────────────────────
        output_dir = Path("reports/calibration")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save Platt parameters (cast to built-in float so YAML stays plain)
        calibration_config = {
            "platt_scaling": {
                "a": float(a),
                "b": float(b),
                "ece_uncalibrated": float(ece_uncal),
                "ece_calibrated": float(ece_cal),
                "brier_uncalibrated": float(brier_uncal),
                "brier_calibrated": float(brier_cal),
            }
        }

        config_path = output_dir / "calibration_params.yaml"
        with open(config_path, "w") as f:
            yaml.dump(calibration_config, f, default_flow_style=False)
        mlflow.log_artifact(str(config_path))
        logger.info(f"\n✓ Calibration parameters saved to {config_path}")

        # Plot reliability diagram
        reliability_path = output_dir / "reliability_diagram.png"
        plot_reliability_diagram(labels, probs_uncal, probs_cal, reliability_path)
        mlflow.log_artifact(str(reliability_path))
        logger.info(f"✓ Reliability diagram saved to {reliability_path}")

        # Plot histogram comparison
        hist_path = output_dir / "probability_histograms.png"
        plot_histogram_comparison(labels, probs_uncal, probs_cal, hist_path)
        mlflow.log_artifact(str(hist_path))
        logger.info(f"✓ Probability histograms saved to {hist_path}")

        logger.info("\n✓ Calibration complete!")
        logger.info(f" Update configs/model_config.yaml with Platt parameters:")
        logger.info(f" a: {a:.6f}")
        logger.info(f" b: {b:.6f}")


if __name__ == "__main__":
    main()
|
training/dataset.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch Dataset for dysarthria detection."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from src.ingestion.audio_loader import AudioLoader
|
| 11 |
+
from src.ingestion.preprocessor import AudioPreprocessor
|
| 12 |
+
from src.features.mfcc_extractor import MFCCExtractor
|
| 13 |
+
from src.features.prosodic_extractor import ProsodicExtractor
|
| 14 |
+
from src.features.formant_extractor import FormantExtractor
|
| 15 |
+
from src.features.egemaps_extractor import EGeMAPSExtractor
|
| 16 |
+
from src.features.spectrogram_builder import SpectrogramBuilder
|
| 17 |
+
from src.features.feature_fusion import FeatureFusion
|
| 18 |
+
from src.features.schemas import FeatureBundle
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DysarthriaDataset(Dataset):
|
| 24 |
+
"""Dataset for dysarthria detection with on-the-fly feature extraction."""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
manifest_path: str | Path,
|
| 29 |
+
augment: bool = False,
|
| 30 |
+
cache_features: bool = False,
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Initialize dataset.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
manifest_path: Path to CSV manifest (filepath, label, speaker_id, duration)
|
| 37 |
+
augment: Apply data augmentation
|
| 38 |
+
cache_features: Cache extracted features in memory
|
| 39 |
+
"""
|
| 40 |
+
self.manifest = pd.read_csv(manifest_path)
|
| 41 |
+
self.augment = augment
|
| 42 |
+
self.cache_features = cache_features
|
| 43 |
+
self.feature_cache = {} if cache_features else None
|
| 44 |
+
|
| 45 |
+
# Initialize components
|
| 46 |
+
self.audio_loader = AudioLoader()
|
| 47 |
+
self.preprocessor = AudioPreprocessor(target_sr=16000)
|
| 48 |
+
self.mfcc_extractor = MFCCExtractor()
|
| 49 |
+
self.prosodic_extractor = ProsodicExtractor()
|
| 50 |
+
self.formant_extractor = FormantExtractor()
|
| 51 |
+
self.egemaps_extractor = EGeMAPSExtractor()
|
| 52 |
+
self.spectrogram_builder = SpectrogramBuilder()
|
| 53 |
+
self.feature_fusion = FeatureFusion()
|
| 54 |
+
|
| 55 |
+
logger.info(f"Dataset initialized: {len(self)} samples")
|
| 56 |
+
|
| 57 |
+
def __len__(self) -> int:
|
| 58 |
+
return len(self.manifest)
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, idx: int) -> dict:
|
| 61 |
+
"""
|
| 62 |
+
Get item by index.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
dict with keys:
|
| 66 |
+
- waveform: torch.Tensor (samples,)
|
| 67 |
+
- spectrogram: torch.Tensor (2, freq, time)
|
| 68 |
+
- acoustic_features: torch.Tensor (n_features,)
|
| 69 |
+
- label: torch.Tensor (1,)
|
| 70 |
+
- speaker_id: str
|
| 71 |
+
"""
|
| 72 |
+
# Check cache
|
| 73 |
+
if self.cache_features and idx in self.feature_cache:
|
| 74 |
+
return self.feature_cache[idx]
|
| 75 |
+
|
| 76 |
+
# Load sample info
|
| 77 |
+
row = self.manifest.iloc[idx]
|
| 78 |
+
audio_path = row["file_path"] # Changed from "filepath" to "file_path"
|
| 79 |
+
label = int(row["label"])
|
| 80 |
+
speaker_id = row["speaker_id"]
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
# Load and preprocess audio
|
| 84 |
+
audio_input, waveform = self.audio_loader.load(audio_path)
|
| 85 |
+
preprocessed = self.preprocessor.process(
|
| 86 |
+
waveform,
|
| 87 |
+
sr=audio_input.sample_rate, # Use original SR
|
| 88 |
+
original_duration=row["duration"],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
waveform = preprocessed.waveform
|
| 92 |
+
sr = preprocessed.sample_rate
|
| 93 |
+
|
| 94 |
+
# Apply augmentation if training
|
| 95 |
+
if self.augment:
|
| 96 |
+
waveform = self._apply_augmentation(waveform, sr)
|
| 97 |
+
|
| 98 |
+
# Extract features
|
| 99 |
+
mfcc = self.mfcc_extractor.extract(waveform, sr)
|
| 100 |
+
prosody = self.prosodic_extractor.extract(waveform, sr)
|
| 101 |
+
formants = self.formant_extractor.extract(waveform, sr)
|
| 102 |
+
egemaps = self.egemaps_extractor.extract(waveform, sr)
|
| 103 |
+
spectrogram = self.spectrogram_builder.build(waveform, sr)
|
| 104 |
+
|
| 105 |
+
# Create feature bundle
|
| 106 |
+
feature_bundle = FeatureBundle(
|
| 107 |
+
waveform=waveform,
|
| 108 |
+
sample_rate=sr,
|
| 109 |
+
duration_sec=preprocessed.duration_sec,
|
| 110 |
+
mfcc=mfcc,
|
| 111 |
+
prosody=prosody,
|
| 112 |
+
formants=formants,
|
| 113 |
+
egemaps=egemaps,
|
| 114 |
+
spectrogram=spectrogram,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Fuse acoustic features
|
| 118 |
+
feature_bundle = self.feature_fusion.fuse(feature_bundle)
|
| 119 |
+
|
| 120 |
+
# Convert to tensors
|
| 121 |
+
item = {
|
| 122 |
+
"waveform": torch.from_numpy(waveform).float(),
|
| 123 |
+
"spectrogram": torch.from_numpy(spectrogram.stacked).float(),
|
| 124 |
+
"acoustic_features": torch.from_numpy(feature_bundle.fused_acoustic).float(),
|
| 125 |
+
"label": torch.tensor([label], dtype=torch.long),
|
| 126 |
+
"speaker_id": speaker_id,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# Cache if enabled
|
| 130 |
+
if self.cache_features:
|
| 131 |
+
self.feature_cache[idx] = item
|
| 132 |
+
|
| 133 |
+
return item
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Failed to load sample {idx} ({audio_path}): {e}")
|
| 137 |
+
# Return a dummy sample
|
| 138 |
+
return self._get_dummy_item(label, speaker_id)
|
| 139 |
+
|
| 140 |
+
def _apply_augmentation(self, waveform: np.ndarray, sr: int) -> np.ndarray:
|
| 141 |
+
"""Apply data augmentation."""
|
| 142 |
+
from training.augmentation import AudioAugmenter
|
| 143 |
+
|
| 144 |
+
augmenter = AudioAugmenter()
|
| 145 |
+
return augmenter.augment(waveform, sr)
|
| 146 |
+
|
| 147 |
+
def _get_dummy_item(self, label: int, speaker_id: str) -> dict:
|
| 148 |
+
"""Return a dummy item when loading fails."""
|
| 149 |
+
return {
|
| 150 |
+
"waveform": torch.zeros(16000 * 10), # 10 seconds of silence
|
| 151 |
+
"spectrogram": torch.zeros(2, 128, 313),
|
| 152 |
+
"acoustic_features": torch.zeros(145),
|
| 153 |
+
"label": torch.tensor([label], dtype=torch.long),
|
| 154 |
+
"speaker_id": speaker_id,
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def collate_fn(batch: list[dict]) -> dict:
    """
    Merge a list of dataset items into one batch dict.

    Waveforms are zero-padded on the right to the longest waveform in the
    batch, and spectrograms are zero-padded along their last (time) axis to
    the widest spectrogram. Tensors are then stacked along a new leading
    batch axis; speaker ids are returned as a plain list.
    """

    def _right_pad(tensor, current, target):
        # Only allocate a padded copy when the tensor is actually short.
        if current < target:
            return torch.nn.functional.pad(tensor, (0, target - current))
        return tensor

    # Targets for padding
    longest_wave = max(sample["waveform"].shape[0] for sample in batch)
    widest_spec = max(sample["spectrogram"].shape[2] for sample in batch)

    padded_waves = [
        _right_pad(sample["waveform"], sample["waveform"].shape[0], longest_wave)
        for sample in batch
    ]
    padded_specs = [
        _right_pad(sample["spectrogram"], sample["spectrogram"].shape[2], widest_spec)
        for sample in batch
    ]

    return {
        "waveform": torch.stack(padded_waves),
        "spectrogram": torch.stack(padded_specs),
        "acoustic_features": torch.stack(
            [sample["acoustic_features"] for sample in batch]
        ),
        "label": torch.stack([sample["label"] for sample in batch]),
        "speaker_id": [sample["speaker_id"] for sample in batch],
    }
|
training/evaluate.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive model evaluation on test set.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads trained ensemble model with calibration
|
| 6 |
+
2. Evaluates on held-out test set
|
| 7 |
+
3. Computes classification metrics (accuracy, F1, AUC, sensitivity, specificity)
|
| 8 |
+
4. Generates confusion matrix, ROC curve, PR curve
|
| 9 |
+
5. Performs error analysis
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python training/evaluate.py
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
import mlflow
|
| 23 |
+
import numpy as np
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
import yaml
|
| 26 |
+
import pandas as pd
|
| 27 |
+
import logging
|
| 28 |
+
from sklearn.metrics import (
|
| 29 |
+
accuracy_score,
|
| 30 |
+
f1_score,
|
| 31 |
+
roc_auc_score,
|
| 32 |
+
confusion_matrix,
|
| 33 |
+
classification_report,
|
| 34 |
+
roc_curve,
|
| 35 |
+
precision_recall_curve,
|
| 36 |
+
average_precision_score,
|
| 37 |
+
)
|
| 38 |
+
import matplotlib.pyplot as plt
|
| 39 |
+
import seaborn as sns
|
| 40 |
+
|
| 41 |
+
from training.dataset import DysarthriaDataset
|
| 42 |
+
from training.train_hubert_salr import HuBERTSALRModel
|
| 43 |
+
from training.train_cnn_bilstm import CNNBiLSTMTransformer
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 47 |
+
# Model Inference
|
| 48 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 49 |
+
|
| 50 |
+
def evaluate_model(hubert_model, cnn_model, dataloader, alpha, platt_a, platt_b, device):
    """
    Evaluate calibrated ensemble on test set.

    Both models are put in eval mode and run under ``torch.no_grad()``. Their
    logits are blended as ``alpha * hubert + (1 - alpha) * cnn``. The class-1
    logit of the blend goes through Platt scaling,
    ``sigmoid(platt_a * logit + platt_b)``, and is thresholded at 0.5.

    Args:
        hubert_model: Model called on ``batch["waveform"]``.
        cnn_model: Model called on ``batch["spectrogram"]``.
        dataloader: Iterable of dict batches with "waveform", "spectrogram"
            and "label" entries, plus an optional "file_path" entry.
        alpha: Weight given to the HuBERT logits in the blend.
        platt_a: Platt scaling slope.
        platt_b: Platt scaling intercept.
        device: Device the model inputs are moved to.

    Returns:
        predictions, probabilities, labels, file_paths — the first three as
        1-D NumPy arrays with one entry per sample, the last as a list.
    """
    all_preds = []
    all_probs = []
    all_labels = []
    all_files = []

    hubert_model.eval()
    cnn_model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)
            # Labels may arrive as (batch,) or as (batch, 1) when each item
            # carries a one-element label tensor; flatten so the returned
            # array is always 1-D and lines up with preds/probs.
            labels = batch["label"].reshape(-1)
            file_paths = batch.get("file_path", [""] * len(labels))

            # Ensemble logits
            hubert_logits = hubert_model(waveform)
            cnn_logits = cnn_model(spectrogram)
            ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits

            # Apply Platt scaling to the positive-class logit
            raw_logits = ensemble_logits[:, 1].cpu().numpy()
            z = platt_a * raw_logits + platt_b
            calibrated_probs = 1 / (1 + np.exp(-z))

            # Predictions
            preds = (calibrated_probs > 0.5).astype(int)

            all_preds.extend(preds)
            all_probs.extend(calibrated_probs)
            all_labels.extend(labels.cpu().numpy())
            all_files.extend(file_paths)

    return (
        np.array(all_preds),
        np.array(all_probs),
        np.array(all_labels),
        all_files,
    )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 99 |
+
# Metrics Computation
|
| 100 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 101 |
+
|
| 102 |
+
def compute_metrics(y_true, y_pred, y_prob):
    """
    Compute comprehensive classification metrics.

    Args:
        y_true: Ground-truth binary labels (0 = negative, 1 = positive).
        y_pred: Predicted binary labels.
        y_prob: Predicted probability of the positive class.

    Returns:
        Dict with accuracy, f1, auc, average_precision, sensitivity,
        specificity, ppv and npv as Python floats; tp, tn, fp and fn as
        Python ints; and the 2x2 "confusion_matrix" array.
    """
    # Basic metrics, cast to plain floats so they serialise cleanly (e.g. to YAML)
    accuracy = float(accuracy_score(y_true, y_pred))
    f1 = float(f1_score(y_true, y_pred, average="binary"))
    auc = float(roc_auc_score(y_true, y_prob))
    ap = float(average_precision_score(y_true, y_prob))

    # Pin the label set so the matrix is always 2x2, even if one class never
    # occurs; otherwise the four-way unpacking below fails.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (int(count) for count in cm.ravel())

    # Sensitivity and specificity (0.0 when the denominator is empty)
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    # Positive and negative predictive value
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0

    return {
        "accuracy": accuracy,
        "f1": f1,
        "auc": auc,
        "average_precision": ap,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "ppv": ppv,
        "npv": npv,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "confusion_matrix": cm,
    }
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 140 |
+
# Visualization
|
| 141 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 142 |
+
|
| 143 |
+
def plot_confusion_matrix(cm, output_path: Path):
    """Draw the confusion matrix as an annotated heatmap and save it to ``output_path``."""
    class_names = ["Healthy", "Dysarthric"]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={"label": "Count"},
        ax=ax,
    )

    # Rows are the true classes, columns the predicted ones
    ax.set_title("Confusion Matrix - Test Set", fontsize=16, fontweight="bold")
    ax.set_ylabel("True Label", fontsize=14)
    ax.set_xlabel("Predicted Label", fontsize=14)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def plot_roc_curve(y_true, y_prob, auc_score, output_path: Path):
    """Draw the ROC curve with a chance diagonal and save it to ``output_path``."""
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_prob)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos_rate, true_pos_rate, linewidth=2, label=f"Model (AUC = {auc_score:.4f})")
    # Dashed diagonal marks the performance of random guessing
    ax.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random Classifier")

    ax.set_xlabel("False Positive Rate", fontsize=14)
    ax.set_ylabel("True Positive Rate", fontsize=14)
    ax.set_title("ROC Curve - Test Set", fontsize=16, fontweight="bold")
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def plot_precision_recall_curve(y_true, y_prob, ap_score, output_path: Path):
    """Draw the precision-recall curve on a unit square and save it to ``output_path``."""
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_prob)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall_vals, precision_vals, linewidth=2, label=f"Model (AP = {ap_score:.4f})")

    ax.set_xlabel("Recall", fontsize=14)
    ax.set_ylabel("Precision", fontsize=14)
    ax.set_title("Precision-Recall Curve - Test Set", fontsize=16, fontweight="bold")
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    # Both axes are rates, so clamp them to [0, 1]
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def plot_probability_distribution(y_true, y_prob, output_path: Path):
    """Overlay per-class histograms of the predicted probabilities and save to ``output_path``."""
    fig, ax = plt.subplots(figsize=(10, 6))

    # Healthy (label 0) is drawn first, dysarthric (label 1) on top of it
    class_styles = ((0, "blue", "Healthy"), (1, "red", "Dysarthric"))
    for class_value, colour, class_name in class_styles:
        ax.hist(
            y_prob[y_true == class_value],
            bins=30,
            alpha=0.6,
            color=colour,
            label=class_name,
            edgecolor="black",
        )

    # Vertical marker at the 0.5 cut-off used for the hard predictions
    ax.axvline(0.5, color="black", linestyle="--", linewidth=2, label="Decision Threshold")

    ax.set_xlabel("Predicted Probability", fontsize=14)
    ax.set_ylabel("Count", fontsize=14)
    ax.set_title("Predicted Probability Distribution - Test Set", fontsize=16, fontweight="bold")
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 237 |
+
# Error Analysis
|
| 238 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 239 |
+
|
| 240 |
+
def perform_error_analysis(y_true, y_pred, y_prob, file_paths, output_path: Path):
    """
    Identify and save misclassified samples.

    Args:
        y_true: Ground-truth labels (1 = dysarthric, anything else = healthy).
        y_pred: Predicted labels, aligned with ``y_true``.
        y_prob: Predicted probability of the dysarthric class.
        file_paths: One identifier per sample, aligned with ``y_true``.
        output_path: Destination of the CSV report.

    Returns:
        DataFrame with one row per misclassified sample, sorted so the most
        confident mistakes come first. When there are no mistakes the frame
        is empty but still has all columns, and the CSV holds only a header.
    """
    # Fixing the columns up front keeps the schema stable when nothing is
    # misclassified; an empty frame without columns cannot be sorted by
    # "confidence" or filtered on "error_type".
    columns = [
        "file_path",
        "true_label",
        "predicted_label",
        "probability",
        "confidence",
        "error_type",
    ]
    errors = []

    for true_label, pred_label, prob, file_path in zip(y_true, y_pred, y_prob, file_paths):
        if true_label != pred_label:
            error_type = "False Positive" if pred_label == 1 else "False Negative"
            # Confidence is the probability the model gave to the class it chose
            confidence = prob if pred_label == 1 else (1 - prob)

            errors.append({
                "file_path": file_path,
                "true_label": "Dysarthric" if true_label == 1 else "Healthy",
                "predicted_label": "Dysarthric" if pred_label == 1 else "Healthy",
                "probability": prob,
                "confidence": confidence,
                "error_type": error_type,
            })

    errors_df = pd.DataFrame(errors, columns=columns)
    errors_df = errors_df.sort_values("confidence", ascending=False)
    errors_df.to_csv(output_path, index=False)

    return errors_df
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 268 |
+
# Main
|
| 269 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 270 |
+
|
| 271 |
+
def main():
    """
    Evaluate the calibrated ensemble on the test set and write all reports.

    Reads the ensemble weight from configs/model_config.yaml and the Platt
    parameters from reports/calibration/calibration_params.yaml (identity
    mapping if that file is absent). Loads both model checkpoints, runs
    inference over data/manifests/test.csv, logs parameters, metrics and
    artifacts to MLflow, and writes metrics, a classification report, plots
    and a misclassification CSV under reports/evaluation.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Configuration
    # ──────────────────────────────────────────────────────────────────────────
    with open("configs/model_config.yaml") as f:
        config = yaml.safe_load(f)

    alpha = config.get("ensemble", {}).get("alpha", 0.6)

    # Load Platt scaling parameters
    calibration_file = Path("reports/calibration/calibration_params.yaml")
    if calibration_file.exists():
        with open(calibration_file) as f:
            cal_config = yaml.safe_load(f)
        platt_a = cal_config["platt_scaling"]["a"]
        platt_b = cal_config["platt_scaling"]["b"]
        logger.info(f"Loaded Platt parameters: a={platt_a:.6f}, b={platt_b:.6f}")
    else:
        platt_a, platt_b = 1.0, 0.0
        logger.warning("Calibration parameters not found, using identity mapping")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Models
    # ──────────────────────────────────────────────────────────────────────────
    logger.info("Loading models...")

    hubert_checkpoint = Path("models/hubert_salr_best.pt")
    cnn_checkpoint = Path("models/cnn_bilstm_best.pt")

    hubert_model = HuBERTSALRModel()
    hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"])
    hubert_model.to(device)

    cnn_model = CNNBiLSTMTransformer()
    cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"])
    cnn_model.to(device)

    # ──────────────────────────────────────────────────────────────────────────
    # Load Test Data
    # ──────────────────────────────────────────────────────────────────────────
    test_manifest = Path("data/manifests/test.csv")
    test_dataset = DysarthriaDataset(test_manifest, augmentor=None, mode="test")
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

    logger.info(f"Test samples: {len(test_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Evaluate
    # ──────────────────────────────────────────────────────────────────────────
    mlflow.set_experiment("model_evaluation")

    with mlflow.start_run():
        logger.info("\nEvaluating on test set...")

        y_pred, y_prob, y_true, file_paths = evaluate_model(
            hubert_model, cnn_model, test_loader, alpha, platt_a, platt_b, device
        )

        # Compute metrics
        metrics = compute_metrics(y_true, y_pred, y_prob)

        # ──────────────────────────────────────────────────────────────────────
        # Print Results
        # ──────────────────────────────────────────────────────────────────────
        logger.info("\n" + "=" * 80)
        logger.info("TEST SET EVALUATION RESULTS")
        logger.info("=" * 80)
        logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
        logger.info(f"F1 Score: {metrics['f1']:.4f}")
        logger.info(f"AUC-ROC: {metrics['auc']:.4f}")
        logger.info(f"Average Precision: {metrics['average_precision']:.4f}")
        logger.info(f"Sensitivity: {metrics['sensitivity']:.4f}")
        logger.info(f"Specificity: {metrics['specificity']:.4f}")
        logger.info(f"PPV: {metrics['ppv']:.4f}")
        logger.info(f"NPV: {metrics['npv']:.4f}")
        logger.info("")
        logger.info("Confusion Matrix:")
        logger.info(f" True Negatives: {metrics['tn']}")
        logger.info(f" False Positives: {metrics['fp']}")
        logger.info(f" False Negatives: {metrics['fn']}")
        logger.info(f" True Positives: {metrics['tp']}")
        logger.info("=" * 80)

        # Log to MLflow
        mlflow.log_params({
            "ensemble_alpha": alpha,
            "platt_a": platt_a,
            "platt_b": platt_b,
            "test_samples": len(y_true),
        })

        mlflow.log_metrics({
            "test_accuracy": metrics["accuracy"],
            "test_f1": metrics["f1"],
            "test_auc": metrics["auc"],
            "test_ap": metrics["average_precision"],
            "test_sensitivity": metrics["sensitivity"],
            "test_specificity": metrics["specificity"],
            "test_ppv": metrics["ppv"],
            "test_npv": metrics["npv"],
        })

        # ──────────────────────────────────────────────────────────────────────
        # Save Results
        # ──────────────────────────────────────────────────────────────────────
        output_dir = Path("reports/evaluation")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save metrics
        metrics_file = output_dir / "test_metrics.yaml"
        with open(metrics_file, "w") as f:
            # Convert numpy scalars to Python types: yaml.dump has no plain
            # representation for them and would emit python/object tags.
            # The confusion matrix array is left out of the YAML file.
            metrics_to_save = {
                k: (v.item() if isinstance(v, np.generic) else v)
                for k, v in metrics.items()
                if k != "confusion_matrix"
            }
            yaml.dump(metrics_to_save, f, default_flow_style=False)
        mlflow.log_artifact(str(metrics_file))
        logger.info(f"\n✓ Metrics saved to {metrics_file}")

        # Classification report
        report = classification_report(
            y_true,
            y_pred,
            target_names=["Healthy", "Dysarthric"],
            digits=4,
        )
        report_file = output_dir / "classification_report.txt"
        with open(report_file, "w") as f:
            f.write(report)
        mlflow.log_artifact(str(report_file))
        logger.info(f"✓ Classification report saved to {report_file}")

        # Confusion matrix
        cm_path = output_dir / "confusion_matrix.png"
        plot_confusion_matrix(metrics["confusion_matrix"], cm_path)
        mlflow.log_artifact(str(cm_path))
        logger.info(f"✓ Confusion matrix plot saved to {cm_path}")

        # ROC curve
        roc_path = output_dir / "roc_curve.png"
        plot_roc_curve(y_true, y_prob, metrics["auc"], roc_path)
        mlflow.log_artifact(str(roc_path))
        logger.info(f"✓ ROC curve saved to {roc_path}")

        # Precision-Recall curve
        pr_path = output_dir / "precision_recall_curve.png"
        plot_precision_recall_curve(y_true, y_prob, metrics["average_precision"], pr_path)
        mlflow.log_artifact(str(pr_path))
        logger.info(f"✓ Precision-Recall curve saved to {pr_path}")

        # Probability distribution
        prob_dist_path = output_dir / "probability_distribution.png"
        plot_probability_distribution(y_true, y_prob, prob_dist_path)
        mlflow.log_artifact(str(prob_dist_path))
        logger.info(f"✓ Probability distribution saved to {prob_dist_path}")

        # Error analysis
        errors_file = output_dir / "misclassified_samples.csv"
        errors_df = perform_error_analysis(y_true, y_pred, y_prob, file_paths, errors_file)
        mlflow.log_artifact(str(errors_file))
        logger.info(f"✓ Error analysis saved to {errors_file}")
        logger.info(f" Total errors: {len(errors_df)}")
        logger.info(f" False Positives: {len(errors_df[errors_df['error_type'] == 'False Positive'])}")
        logger.info(f" False Negatives: {len(errors_df[errors_df['error_type'] == 'False Negative'])}")

        logger.info("\n✓ Evaluation complete!")


if __name__ == "__main__":
    main()
|
training/train_cnn_bilstm.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for CNN-BiLSTM-Transformer model (spectrogram branch).
|
| 3 |
+
|
| 4 |
+
This model processes log-mel spectrograms and CWT scalograms through:
|
| 5 |
+
1. CNN feature extraction (ResNet-style blocks)
|
| 6 |
+
2. BiLSTM temporal modeling
|
| 7 |
+
3. Transformer encoder with self-attention
|
| 8 |
+
4. Classification head
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python training/train_cnn_bilstm.py
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
import mlflow
|
| 23 |
+
import numpy as np
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
import yaml
|
| 26 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
|
| 27 |
+
import logging
|
| 28 |
+
|
| 29 |
+
from training.dataset import DysarthriaDataset
|
| 30 |
+
from training.augmentation import AudioAugmentor
|
| 31 |
+
|
| 32 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 33 |
+
# CNN-BiLSTM-Transformer Model Architecture
|
| 34 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 35 |
+
|
| 36 |
+
class ResidualBlock(nn.Module):
    """Two-convolution residual unit used by the CNN feature extractor.

    The main path is conv3x3 -> batch norm -> ReLU -> conv3x3 -> batch norm.
    Its output is added to the shortcut and passed through a final ReLU.
    Padding of 1 on the 3x3 convolutions keeps the spatial size unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut needs a 1x1 projection only when the channel count changes
        channels_match = in_channels == out_channels
        self.skip = (
            nn.Identity()
            if channels_match
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        shortcut = self.skip(x)

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = torch.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))

        return torch.relu(hidden + shortcut)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CNNBiLSTMTransformer(nn.Module):
    """
    Spectrogram-based dysarthria detection model.

    Architecture:
    - CNN: Extract spatial features from spectrogram
    - BiLSTM: Model temporal dependencies
    - Transformer: Self-attention for long-range patterns
    - Classifier: Binary classification head
    """

    def __init__(
        self,
        input_channels: int = 2,  # Log-mel + CWT
        cnn_channels: list = [64, 128, 256],  # read-only here, never mutated
        lstm_hidden: int = 256,
        transformer_heads: int = 8,
        transformer_layers: int = 4,
        dropout: float = 0.3,
    ):
        """
        Build the network.

        Args:
            input_channels: Number of channels in the input spectrogram.
            cnn_channels: Output channels of each residual block, in order.
                An empty sequence skips the CNN stage entirely.
            lstm_hidden: Hidden size of the LSTM in each direction.
            transformer_heads: Attention heads per Transformer layer.
            transformer_layers: Number of Transformer encoder layers.
            dropout: Dropout rate for the LSTM, Transformer and classifier.
        """
        super().__init__()

        # ─────────────────────────────────────────────────────────────────────
        # CNN Feature Extractor
        # ─────────────────────────────────────────────────────────────────────
        self.cnn_blocks = nn.ModuleList()
        in_ch = input_channels
        for out_ch in cnn_channels:
            self.cnn_blocks.append(ResidualBlock(in_ch, out_ch))
            in_ch = out_ch

        # Collapse the frequency axis (dim 2) to size 1 and leave the time
        # axis (dim 3) untouched, so forward() can squeeze dim 2 away.
        self.pool = nn.AdaptiveAvgPool2d((1, None))

        # ─────────────────────────────────────────────────────────────────────
        # BiLSTM Temporal Modeling
        # ─────────────────────────────────────────────────────────────────────
        self.lstm = nn.LSTM(
            input_size=in_ch,  # channels leaving the CNN stack
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

        # ─────────────────────────────────────────────────────────────────────
        # Transformer Encoder
        # ─────────────────────────────────────────────────────────────────────
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=lstm_hidden * 2,  # Bidirectional
            nhead=transformer_heads,
            dim_feedforward=lstm_hidden * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        # ─────────────────────────────────────────────────────────────────────
        # Classification Head
        # ─────────────────────────────────────────────────────────────────────
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2),  # Binary: healthy vs dysarthric
        )

    def forward(self, spectrogram):
        """
        Args:
            spectrogram: (batch, channels, freq, time) - Log-mel + CWT

        Returns:
            logits: (batch, 2)
        """
        # CNN feature extraction
        x = spectrogram
        for block in self.cnn_blocks:
            x = block(x)

        # Pool frequency dimension: (batch, channels, freq, time) → (batch, channels, time)
        x = self.pool(x).squeeze(2)

        # Transpose for LSTM: (batch, time, channels)
        x = x.transpose(1, 2)

        # BiLSTM
        x, _ = self.lstm(x)

        # Transformer encoder
        x = self.transformer(x)

        # Global average pooling over time
        x = x.mean(dim=1)  # (batch, lstm_hidden*2)

        # Classification
        logits = self.classifier(x)

        return logits
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 166 |
+
# Training Loop
|
| 167 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 168 |
+
|
| 169 |
+
def train_epoch(model, dataloader, optimizer, criterion, device):
    """
    Train for one epoch.

    Args:
        model: Network called on ``batch["spectrogram"]``.
        dataloader: Iterable of dict batches with "spectrogram" and "label".
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy, f1) where avg_loss is the mean of the per-batch
        losses and accuracy/f1 are computed over all samples of the epoch.
    """
    model.train()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    for batch in tqdm(dataloader, desc="Training"):
        spectrogram = batch["spectrogram"].to(device)
        # Labels may arrive as (batch,) or as (batch, 1) when each item
        # carries a one-element label tensor; flatten to one class index
        # per sample for the loss and the metrics.
        labels = batch["label"].to(device).reshape(-1)

        # Forward pass
        optimizer.zero_grad()
        logits = model(spectrogram)
        loss = criterion(logits, labels)

        # Backward pass, with the gradient norm clipped to 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Metrics
        total_loss += loss.item()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="binary")

    return avg_loss, accuracy, f1
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def validate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` with gradients disabled.

    Each batch is a mapping providing ``"spectrogram"`` and ``"label"``
    tensors. The softmax probability of class index 1 is used as the score
    for ROC-AUC.

    Returns:
        Tuple ``(avg_loss, accuracy, f1, auc)``; the loss is averaged over
        the number of batches.
    """
    model.eval()
    loss_sum = 0
    hard_preds, pos_probs, targets = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            inputs = batch["spectrogram"].to(device)
            labels = batch["label"].to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()

            batch_probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            batch_preds = torch.argmax(logits, dim=1).cpu().numpy()

            hard_preds.extend(batch_preds)
            pos_probs.extend(batch_probs)
            targets.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(dataloader)
    acc = accuracy_score(targets, hard_preds)
    f1 = f1_score(targets, hard_preds, average="binary")
    auc = roc_auc_score(targets, pos_probs)
    return mean_loss, acc, f1, auc
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def main():
    """Train the CNN-BiLSTM-Transformer classifier and track the run in MLflow.

    Reads ``configs/model_config.yaml`` (section ``cnn_bilstm``: ``batch_size``,
    ``epochs``), trains on ``data/manifests/train.csv``, validates on
    ``data/manifests/val.csv`` and saves the checkpoint with the highest
    validation AUC to ``models/cnn_bilstm_best.pt``.
    """
    # ──────────────────────────────────────────────────────────────────────────
    # Setup
    # ──────────────────────────────────────────────────────────────────────────
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load config. An empty YAML file parses to None and a present-but-null
    # section parses to None as well, so fall back to empty dicts.
    config_path = Path("configs/model_config.yaml")
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    model_cfg = config.get("cnn_bilstm") or {}

    # Hyperparameters used in several places below (loaders, optimizer, MLflow).
    batch_size = model_cfg.get("batch_size", 16)
    num_epochs = model_cfg.get("epochs", 30)
    learning_rate = 1e-4

    # MLflow setup
    mlflow.set_experiment("cnn_bilstm_transformer_training")

    # ──────────────────────────────────────────────────────────────────────────
    # Data Loading
    # ──────────────────────────────────────────────────────────────────────────
    train_manifest = Path("data/manifests/train.csv")
    val_manifest = Path("data/manifests/val.csv")

    # NOTE(review): training/augmentation.py in this change set defines
    # `AudioAugmenter` taking `noise_level_range`; confirm that the
    # `AudioAugmentor` imported by this script accepts `noise_level`.
    augmentor = AudioAugmentor(
        time_stretch_range=(0.9, 1.1),
        pitch_shift_range=(-2, 2),
        noise_level=0.005,
    )

    train_dataset = DysarthriaDataset(train_manifest, augmentor=augmentor, mode="train")
    val_dataset = DysarthriaDataset(val_manifest, augmentor=None, mode="val")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Model Setup
    # ──────────────────────────────────────────────────────────────────────────
    model = CNNBiLSTMTransformer(
        input_channels=2,
        cnn_channels=[64, 128, 256],
        lstm_hidden=256,
        transformer_heads=8,
        transformer_layers=4,
        dropout=0.3,
    ).to(device)

    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer and scheduler. The deprecated `verbose` flag is not used;
    # learning-rate reductions are logged explicitly in the epoch loop.
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    # Plain (unweighted) cross-entropy.
    # TODO: pass `weight=` here if class imbalance needs to be compensated.
    criterion = nn.CrossEntropyLoss()

    # ──────────────────────────────────────────────────────────────────────────
    # Training Loop
    # ──────────────────────────────────────────────────────────────────────────
    best_val_auc = 0
    best_model_path = Path("models/cnn_bilstm_best.pt")
    best_model_path.parent.mkdir(parents=True, exist_ok=True)

    with mlflow.start_run():
        # Log hyperparameters
        mlflow.log_params({
            "model": "cnn_bilstm_transformer",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "optimizer": "AdamW",
        })

        for epoch in range(1, num_epochs + 1):
            logger.info(f"\nEpoch {epoch}/{num_epochs}")

            # Train
            train_loss, train_acc, train_f1 = train_epoch(
                model, train_loader, optimizer, criterion, device
            )

            # Validate
            val_loss, val_acc, val_f1, val_auc = validate(
                model, val_loader, criterion, device
            )

            # Learning rate scheduling (driven by validation loss)
            previous_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]["lr"]
            if current_lr < previous_lr:
                logger.info(f"Learning rate reduced: {previous_lr:.2e} -> {current_lr:.2e}")

            # Logging
            logger.info(
                f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}"
            )
            logger.info(
                f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, AUC: {val_auc:.4f}"
            )

            mlflow.log_metrics({
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "train_f1": train_f1,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "val_f1": val_f1,
                "val_auc": val_auc,
                "learning_rate": current_lr,
            }, step=epoch)

            # Save best model (selection metric: validation AUC)
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_auc": val_auc,
                }, best_model_path)
                logger.info(f"✓ New best model saved (AUC: {val_auc:.4f})")
                mlflow.log_artifact(str(best_model_path))

        logger.info(f"\n✓ Training complete! Best validation AUC: {best_val_auc:.4f}")
        mlflow.log_metric("best_val_auc", best_val_auc)


if __name__ == "__main__":
    main()
|
training/train_ensemble_weights.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Optimize ensemble weights between HuBERT-SALR and CNN-BiLSTM models.
|
| 3 |
+
|
| 4 |
+
This script performs grid search to find the optimal alpha (mixing weight):
|
| 5 |
+
ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python training/train_ensemble_weights.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
import mlflow
|
| 19 |
+
import numpy as np
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import yaml
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
|
| 24 |
+
import logging
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import seaborn as sns
|
| 27 |
+
|
| 28 |
+
from training.dataset import DysarthriaDataset
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 32 |
+
# Model Loading Utilities
|
| 33 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 34 |
+
|
| 35 |
+
def load_hubert_salr(checkpoint_path: Path, device):
    """Restore a HuBERT-SALR model from ``checkpoint_path`` in eval mode.

    The checkpoint is expected to contain a ``"model_state_dict"`` entry.
    The returned model has been moved to ``device``.
    """
    # Imported lazily so the training module is only loaded when needed.
    from training.train_hubert_salr import HuBERTSALRModel

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)

    net = HuBERTSALRModel()
    net.load_state_dict(state["model_state_dict"])
    net.to(device)
    net.eval()
    return net
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_cnn_bilstm(checkpoint_path: Path, device):
    """Restore a CNN-BiLSTM model from ``checkpoint_path`` in eval mode.

    The checkpoint is expected to contain a ``"model_state_dict"`` entry.
    The returned model has been moved to ``device``.
    """
    # Imported lazily so the training module is only loaded when needed.
    from training.train_cnn_bilstm import CNNBiLSTMTransformer

    # NOTE(review): the model is built with its default constructor arguments;
    # confirm they match the hyperparameters used when the checkpoint was trained.
    net = CNNBiLSTMTransformer()

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state["model_state_dict"])
    net.to(device)
    net.eval()
    return net
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 60 |
+
# Ensemble Evaluation
|
| 61 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 62 |
+
|
| 63 |
+
def _collect_ensemble_logits(hubert_model, cnn_model, dataloader, device, desc="Collecting logits"):
    """Run both models once over ``dataloader`` and stack their outputs.

    Returns:
        ``(hubert_logits, cnn_logits, labels)``: the two logit tensors stay on
        ``device``; ``labels`` is a NumPy array.
    """
    hubert_chunks, cnn_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc, leave=False):
            waveform = batch["waveform"].to(device)
            spectrogram = batch["spectrogram"].to(device)

            hubert_chunks.append(hubert_model(waveform))
            cnn_chunks.append(cnn_model(spectrogram))
            label_chunks.append(batch["label"])

    labels = torch.cat(label_chunks).cpu().numpy()
    return torch.cat(hubert_chunks), torch.cat(cnn_chunks), labels


def _ensemble_metrics(hubert_logits, cnn_logits, labels, alpha):
    """Mix precomputed logits with weight ``alpha`` and compute the metric dict."""
    weight = float(alpha)
    ensemble_logits = weight * hubert_logits + (1 - weight) * cnn_logits

    # Convert to predictions; class index 1 is the positive class
    probs = torch.softmax(ensemble_logits, dim=1)[:, 1].cpu().numpy()
    preds = torch.argmax(ensemble_logits, dim=1).cpu().numpy()

    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="binary")
    auc = roc_auc_score(labels, probs)

    # Fix the label set so the matrix is always 2x2 and the four-way unpack
    # below cannot fail when a class is absent from labels and predictions.
    cm = confusion_matrix(labels, preds, labels=[0, 1])

    # Compute sensitivity and specificity
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        "alpha": alpha,
        "accuracy": accuracy,
        "f1": f1,
        "auc": auc,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "confusion_matrix": cm,
    }


def evaluate_ensemble(
    hubert_model,
    cnn_model,
    dataloader,
    alpha: float,
    device,
):
    """
    Evaluate ensemble with given alpha weight.

    The ensemble is ``alpha * hubert_logits + (1 - alpha) * cnn_logits``.

    Args:
        hubert_model: HuBERT-SALR model
        cnn_model: CNN-BiLSTM model
        dataloader: Validation data; batches provide "waveform",
            "spectrogram" and "label"
        alpha: Mixing weight (0 to 1)
        device: torch device

    Returns:
        Dict of metrics with keys "alpha", "accuracy", "f1", "auc",
        "sensitivity", "specificity" and "confusion_matrix"
    """
    hubert_logits, cnn_logits, labels = _collect_ensemble_logits(
        hubert_model, cnn_model, dataloader, device, desc=f"Alpha={alpha:.2f}"
    )
    return _ensemble_metrics(hubert_logits, cnn_logits, labels, alpha)


# ══════════════════════════════════════════════════════════════════════════════
# Grid Search
# ══════════════════════════════════════════════════════════════════════════════

def grid_search_alpha(
    hubert_model,
    cnn_model,
    dataloader,
    device,
    alpha_range=(0.0, 1.0),
    num_points=21,
):
    """
    Perform grid search over alpha values.

    Both models are run over the data only once; every alpha is then
    evaluated on the cached logits, since the mixing weight does not affect
    the individual model outputs.

    Args:
        hubert_model: HuBERT-SALR model
        cnn_model: CNN-BiLSTM model
        dataloader: Validation data
        device: torch device
        alpha_range: (min, max) alpha values
        num_points: Number of alpha values to test

    Returns:
        DataFrame with results for each alpha
    """
    alphas = np.linspace(alpha_range[0], alpha_range[1], num_points)

    hubert_logits, cnn_logits, labels = _collect_ensemble_logits(
        hubert_model, cnn_model, dataloader, device
    )
    results = [
        _ensemble_metrics(hubert_logits, cnn_logits, labels, alpha)
        for alpha in alphas
    ]

    return pd.DataFrame(results)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 167 |
+
# Visualization
|
| 168 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 169 |
+
|
| 170 |
+
def plot_alpha_search(results_df, output_path: Path):
    """Draw a 2x2 grid of metric-vs-alpha curves and save it to ``output_path``.

    One panel each for accuracy, F1, AUC and sensitivity; the alpha that
    maximises the panel's metric is highlighted in red and annotated.
    """
    _, axes = plt.subplots(2, 2, figsize=(14, 10))

    # (column name in results_df, human-readable panel title)
    panels = [
        ("accuracy", "Accuracy"),
        ("f1", "F1 Score"),
        ("auc", "AUC-ROC"),
        ("sensitivity", "Sensitivity"),
    ]

    for ax, (metric, title) in zip(axes.flat, panels):
        ax.plot(results_df["alpha"], results_df[metric], marker="o", linewidth=2)
        ax.set_xlabel("Alpha (HuBERT weight)", fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f"{title} vs Alpha", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Mark best alpha
        peak_idx = results_df[metric].idxmax()
        best_alpha = results_df.loc[peak_idx, "alpha"]
        best_value = results_df.loc[peak_idx, metric]

        ax.axvline(best_alpha, color="red", linestyle="--", alpha=0.5)
        ax.scatter([best_alpha], [best_value], color="red", s=100, zorder=5)
        label = f"α={best_alpha:.2f}\n{best_value:.4f}"
        ax.text(best_alpha, best_value, label, ha="center", va="bottom", fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def plot_confusion_matrix(cm, alpha, output_path: Path):
    """Render ``cm`` as an annotated heatmap and save it to ``output_path``.

    ``alpha`` is only used in the figure title. Both axes are labelled
    "Healthy" / "Dysarthric", in that order.
    """
    class_names = ["Healthy", "Dysarthric"]

    plt.figure(figsize=(8, 6))
    heatmap_kwargs = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": class_names,
        "yticklabels": class_names,
    }
    sns.heatmap(cm, **heatmap_kwargs)

    plt.title(f"Confusion Matrix (α={alpha:.2f})", fontsize=14)
    plt.ylabel("True Label", fontsize=12)
    plt.xlabel("Predicted Label", fontsize=12)
    plt.tight_layout()

    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 224 |
+
# Main
|
| 225 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 226 |
+
|
| 227 |
+
def main():
    """Grid-search the ensemble mixing weight and write reports.

    Loads both trained checkpoints, evaluates 21 alpha values in [0, 1] on
    the validation manifest, selects the alpha with the highest AUC, and
    stores CSV results, plots and a YAML config under
    ``reports/ensemble_optimization`` (all also logged to MLflow).
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ──────────────────────────────────────────────────────────────────────────
    # Load Models
    # ──────────────────────────────────────────────────────────────────────────
    hubert_checkpoint = Path("models/hubert_salr_best.pt")
    cnn_checkpoint = Path("models/cnn_bilstm_best.pt")

    # Bail out early (without raising) when a prerequisite checkpoint is missing.
    if not hubert_checkpoint.exists():
        logger.error(f"HuBERT checkpoint not found: {hubert_checkpoint}")
        logger.error("Please train HuBERT-SALR first: python training/train_hubert_salr.py")
        return

    if not cnn_checkpoint.exists():
        logger.error(f"CNN-BiLSTM checkpoint not found: {cnn_checkpoint}")
        logger.error("Please train CNN-BiLSTM first: python training/train_cnn_bilstm.py")
        return

    logger.info("Loading HuBERT-SALR model...")
    hubert_model = load_hubert_salr(hubert_checkpoint, device)

    logger.info("Loading CNN-BiLSTM model...")
    cnn_model = load_cnn_bilstm(cnn_checkpoint, device)

    # ──────────────────────────────────────────────────────────────────────────
    # Load Validation Data
    # ──────────────────────────────────────────────────────────────────────────
    val_manifest = Path("data/manifests/val.csv")
    val_dataset = DysarthriaDataset(val_manifest, augmentor=None, mode="val")
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    logger.info(f"Validation samples: {len(val_dataset)}")

    # ──────────────────────────────────────────────────────────────────────────
    # Grid Search
    # ──────────────────────────────────────────────────────────────────────────
    mlflow.set_experiment("ensemble_weight_optimization")

    with mlflow.start_run():
        logger.info("\nStarting grid search over alpha values...")

        results_df = grid_search_alpha(
            hubert_model,
            cnn_model,
            val_loader,
            device,
            alpha_range=(0.0, 1.0),
            num_points=21,
        )

        # Find best alpha for each metric
        auc_idx = results_df["auc"].idxmax()
        best_alpha_auc = results_df.loc[auc_idx, "alpha"]
        best_alpha_f1 = results_df.loc[results_df["f1"].idxmax(), "alpha"]
        best_alpha_acc = results_df.loc[results_df["accuracy"].idxmax(), "alpha"]

        separator = "=" * 80
        logger.info("\n" + separator)
        logger.info("GRID SEARCH RESULTS")
        logger.info(separator)
        logger.info(f"Best alpha (AUC): {best_alpha_auc:.2f}")
        logger.info(f"Best alpha (F1): {best_alpha_f1:.2f}")
        logger.info(f"Best alpha (Accuracy): {best_alpha_acc:.2f}")
        logger.info(separator)

        # Use AUC as primary metric
        best_alpha = best_alpha_auc
        matching_rows = results_df.loc[results_df["alpha"] == best_alpha]
        best_row = matching_rows.iloc[0]

        logger.info(f"\nOptimal alpha: {best_alpha:.2f}")
        logger.info(f" Accuracy: {best_row['accuracy']:.4f}")
        logger.info(f" F1 Score: {best_row['f1']:.4f}")
        logger.info(f" AUC: {best_row['auc']:.4f}")
        logger.info(f" Sensitivity: {best_row['sensitivity']:.4f}")
        logger.info(f" Specificity: {best_row['specificity']:.4f}")

        # Log to MLflow
        mlflow.log_params({
            "num_alpha_points": 21,
            "alpha_range_min": 0.0,
            "alpha_range_max": 1.0,
        })

        mlflow.log_metrics({
            "best_alpha": best_alpha,
            "best_accuracy": best_row["accuracy"],
            "best_f1": best_row["f1"],
            "best_auc": best_row["auc"],
            "best_sensitivity": best_row["sensitivity"],
            "best_specificity": best_row["specificity"],
        })

        # Save results
        output_dir = Path("reports/ensemble_optimization")
        output_dir.mkdir(parents=True, exist_ok=True)

        results_csv = output_dir / "alpha_search_results.csv"
        results_df.to_csv(results_csv, index=False)
        mlflow.log_artifact(str(results_csv))
        logger.info(f"\n✓ Results saved to {results_csv}")

        # Plot metrics vs alpha
        plot_path = output_dir / "alpha_search_plot.png"
        plot_alpha_search(results_df, plot_path)
        mlflow.log_artifact(str(plot_path))
        logger.info(f"✓ Plots saved to {plot_path}")

        # Plot confusion matrix for best alpha
        cm_path = output_dir / "confusion_matrix_best_alpha.png"
        plot_confusion_matrix(best_row["confusion_matrix"], best_alpha, cm_path)
        mlflow.log_artifact(str(cm_path))
        logger.info(f"✓ Confusion matrix saved to {cm_path}")

        # Save optimal config (plain floats so the YAML stays human-readable)
        validation_metrics = {
            "accuracy": float(best_row["accuracy"]),
            "f1": float(best_row["f1"]),
            "auc": float(best_row["auc"]),
            "sensitivity": float(best_row["sensitivity"]),
            "specificity": float(best_row["specificity"]),
        }
        optimal_config = {
            "ensemble": {
                "alpha": float(best_alpha),
                "hubert_weight": float(best_alpha),
                "cnn_bilstm_weight": float(1 - best_alpha),
                "validation_metrics": validation_metrics,
            }
        }

        config_path = output_dir / "optimal_ensemble_config.yaml"
        with open(config_path, "w") as f:
            yaml.dump(optimal_config, f, default_flow_style=False)
        mlflow.log_artifact(str(config_path))
        logger.info(f"✓ Optimal config saved to {config_path}")

        logger.info("\n✓ Ensemble weight optimization complete!")
        logger.info(f" Update configs/model_config.yaml with alpha={best_alpha:.2f}")


if __name__ == "__main__":
    main()
|
training/train_hubert_fast.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Fast fine-tuning script for HuBERT-SALR model.
|
| 4 |
+
|
| 5 |
+
Optimizations:
|
| 6 |
+
- Reduced dataset size (500-1000 samples)
|
| 7 |
+
- Fewer epochs (5 instead of 20)
|
| 8 |
+
- Simplified model architecture
|
| 9 |
+
- Uses MPS/GPU acceleration
|
| 10 |
+
- Faster feature extraction
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python training/train_hubert_fast.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.optim as optim
|
| 23 |
+
from torch.utils.data import DataLoader, Subset
|
| 24 |
+
import numpy as np
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
import logging
|
| 27 |
+
import pandas as pd
|
| 28 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
|
| 29 |
+
from transformers import HubertModel
|
| 30 |
+
|
| 31 |
+
logging.basicConfig(level=logging.INFO)
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 36 |
+
# Simplified HuBERT Model
|
| 37 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
|
| 39 |
+
class SimplifiedHuBERTClassifier(nn.Module):
    """Simplified HuBERT for faster training.

    A pre-trained HuBERT-base encoder followed by a small MLP head that
    produces two logits (healthy vs dysarthric).

    Args:
        freeze_base: When True (default) the encoder parameters are frozen and
            only the classifier head is trained. When False the encoder is
            fine-tuned together with the head.
    """

    def __init__(self, freeze_base=True):
        super().__init__()

        # Remembered so forward() can decide whether the encoder needs gradients.
        self.freeze_base = freeze_base

        # Load pre-trained HuBERT (smaller version for speed)
        logger.info("Loading HuBERT-base model...")
        self.hubert = HubertModel.from_pretrained("facebook/hubert-base-ls960")

        # Freeze base model for faster training
        if freeze_base:
            for param in self.hubert.parameters():
                param.requires_grad = False
            logger.info("✓ HuBERT base frozen (only training classifier)")

        # Simple classifier head
        hidden_size = self.hubert.config.hidden_size  # 768 for base
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2),  # Binary: healthy vs dysarthric
        )

    def forward(self, input_values):
        """Return class logits of shape (batch, 2) for ``input_values``."""
        # Extract features. Gradient tracking depends on whether the encoder
        # is frozen, not on train/eval mode: a frozen encoder never needs a
        # graph, while an unfrozen one must receive gradients during training.
        # The ambient grad mode (e.g. an outer torch.no_grad()) is respected
        # in the unfrozen case.
        if self.freeze_base:
            with torch.no_grad():
                outputs = self.hubert(input_values)
        else:
            outputs = self.hubert(input_values)

        # Pool: mean across time dimension
        hidden_states = outputs.last_hidden_state  # (batch, time, hidden)
        pooled = hidden_states.mean(dim=1)  # (batch, hidden)

        # Classify
        logits = self.classifier(pooled)
        return logits
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 79 |
+
# Fast Dataset (No Heavy Feature Extraction)
|
| 80 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 81 |
+
|
| 82 |
+
class FastDysarthriaDataset(torch.utils.data.Dataset):
    """Simplified dataset for fast training.

    Reads a CSV manifest with ``file_path``, ``duration`` and ``label``
    columns, keeps only clips whose duration lies in
    ``[min_duration, max_duration]`` and serves fixed-length waveforms.

    Args:
        manifest_path: Path to the CSV manifest.
        max_duration: Longest clip kept, in seconds; also fixes the output
            length (``max_duration * sample_rate`` samples).
        sample_rate: Rate the audio is loaded at.
        min_duration: Shortest clip kept, in seconds.
    """

    def __init__(self, manifest_path, max_duration=10.0, sample_rate=16000, min_duration=5.0):
        self.manifest = pd.read_csv(manifest_path)
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.sample_rate = sample_rate
        self.max_length = int(max_duration * sample_rate)

        # Filter valid files
        self.manifest = self.manifest[
            (self.manifest['duration'] >= min_duration) &  # Min duration
            (self.manifest['duration'] <= max_duration)  # Max duration
        ].reset_index(drop=True)

        # Report the bounds actually applied rather than a fixed "5-10s".
        logger.info(
            f"Dataset: {len(self.manifest)} samples "
            f"(filtered for {min_duration:g}-{max_duration:g}s duration)"
        )

    def __len__(self):
        return len(self.manifest)

    def _fit_length(self, waveform):
        """Truncate or right-pad ``waveform`` with zeros to ``self.max_length``."""
        if len(waveform) > self.max_length:
            return waveform[:self.max_length]
        return np.pad(waveform, (0, self.max_length - len(waveform)))

    def __getitem__(self, idx):
        """Return ``{'waveform': FloatTensor, 'label': int}`` for row ``idx``."""
        row = self.manifest.iloc[idx]

        # Load audio (librosa is imported lazily, on first access)
        import librosa
        waveform, sr = librosa.load(row['file_path'], sr=self.sample_rate)

        # Pad or truncate to fixed length so samples can be batched
        waveform = self._fit_length(waveform)

        return {
            'waveform': torch.FloatTensor(waveform),
            'label': int(row['label']),
        }
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 122 |
+
# Training Functions
|
| 123 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 124 |
+
|
| 125 |
+
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run a single optimisation pass over ``dataloader``.

    Each batch is a mapping providing ``"waveform"`` and ``"label"`` tensors.

    Returns:
        Tuple ``(avg_loss, accuracy, f1)`` where the loss is averaged over
        the number of batches and F1 uses binary averaging.
    """
    model.train()
    running_loss = 0
    predicted, targets = [], []

    for batch in tqdm(dataloader, desc="Training"):
        inputs = batch["waveform"].to(device)
        labels = batch["label"].to(device)

        # Forward, backward, update
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Bookkeeping for epoch-level metrics
        running_loss += loss.item()
        predicted.extend(torch.argmax(logits, dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    return (
        running_loss / len(dataloader),
        accuracy_score(targets, predicted),
        f1_score(targets, predicted, average="binary"),
    )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Network called as ``model(waveform)`` and returning class logits.
        dataloader: Yields dict batches with ``"waveform"`` and ``"label"`` tensors.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, f1, auc)``. ``avg_loss`` is the mean of the
        per-batch losses; ``auc`` is computed from the softmax probability of
        class index 1.
    """
    model.eval()
    batch_losses = []
    predicted = []
    positive_probs = []
    targets = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            inputs = batch["waveform"].to(device)
            y = batch["label"].to(device)

            scores = model(inputs)
            batch_losses.append(criterion(scores, y).item())

            # Column 1 of the softmax is the score used for ROC AUC.
            class1_prob = torch.softmax(scores, dim=1)[:, 1].cpu().numpy()
            hard_pred = scores.argmax(dim=1).cpu().numpy()

            predicted.extend(hard_pred)
            positive_probs.extend(class1_prob)
            targets.extend(y.cpu().numpy())

    mean_loss = sum(batch_losses) / len(dataloader)
    acc = accuracy_score(targets, predicted)
    f1_binary = f1_score(targets, predicted, average="binary")
    auc_value = roc_auc_score(targets, positive_probs)

    return mean_loss, acc, f1_binary, auc_value
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 188 |
+
# Main Training
|
| 189 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 190 |
+
|
| 191 |
+
def main():
    """Run the reduced-size ("fast") HuBERT training job end to end.

    Picks a device, loads the train/val manifests, randomly subsamples them,
    trains only ``model.classifier`` for a few epochs, and checkpoints the
    epoch with the highest validation AUC to ``models/hubert_fast_best.pt``.
    """
    # Device selection: CUDA first, then MPS, then CPU.
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    logger.info(f"🚀 Using device: {device}")

    # Load datasets
    train_dataset = FastDysarthriaDataset(Path("data/manifests/train.csv"), max_duration=10.0)
    val_dataset = FastDysarthriaDataset(Path("data/manifests/val.csv"), max_duration=10.0)

    # Use a random subset of each split for faster training
    MAX_TRAIN_SAMPLES = 500  # Reduced from 3000
    MAX_VAL_SAMPLES = 100  # Reduced from 647

    if len(train_dataset) > MAX_TRAIN_SAMPLES:
        train_pick = np.random.choice(len(train_dataset), MAX_TRAIN_SAMPLES, replace=False)
        train_dataset = Subset(train_dataset, train_pick)
        logger.info(f"✂️ Using subset: {MAX_TRAIN_SAMPLES} training samples")

    if len(val_dataset) > MAX_VAL_SAMPLES:
        val_pick = np.random.choice(len(val_dataset), MAX_VAL_SAMPLES, replace=False)
        val_dataset = Subset(val_dataset, val_pick)
        logger.info(f"✂️ Using subset: {MAX_VAL_SAMPLES} validation samples")

    # Data loaders: small batches, no worker processes (avoids multiprocessing issues)
    loader_kwargs = {"batch_size": 4, "num_workers": 0}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Model
    model = SimplifiedHuBERTClassifier(freeze_base=True).to(device)
    logger.info(f"✓ Model loaded on {device}")

    # Only the classifier head is optimised; higher LR because the base is frozen
    optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    NUM_EPOCHS = 5  # Reduced from 20
    best_val_auc = 0
    best_model_path = Path("models/hubert_fast_best.pt")
    best_model_path.parent.mkdir(parents=True, exist_ok=True)

    banner = '=' * 80
    logger.info(f"\n{banner}")
    logger.info(f" FAST TRAINING - {NUM_EPOCHS} epochs")
    logger.info(f"{banner}\n")

    for epoch in range(1, NUM_EPOCHS + 1):
        logger.info(f"\nEpoch {epoch}/{NUM_EPOCHS}")
        logger.info("-" * 40)

        train_loss, train_acc, train_f1 = train_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_acc, val_f1, val_auc = validate(
            model, val_loader, criterion, device
        )

        logger.info(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.4f}, F1={train_f1:.4f}")
        logger.info(f"Val: Loss={val_loss:.4f}, Acc={val_acc:.4f}, F1={val_f1:.4f}, AUC={val_auc:.4f}")

        # Nothing to save unless this epoch strictly beats the best AUC so far.
        if not val_auc > best_val_auc:
            continue

        best_val_auc = val_auc
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_auc': val_auc,
        }
        torch.save(checkpoint, best_model_path)
        logger.info(f"✓ New best model saved (AUC: {val_auc:.4f})")

    logger.info(f"\n{banner}")
    logger.info(" ✓ TRAINING COMPLETE!")
    logger.info(f"{banner}")
    logger.info(f"Best validation AUC: {best_val_auc:.4f}")
    logger.info(f"Model saved to: {best_model_path}")
    logger.info("\nNext steps:")
    for step_message in (
        " 1. Test the model on test set",
        " 2. Update model_registry.py to use this checkpoint",
        " 3. Run inference on new audio files",
    ):
        logger.info(step_message)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Script entry point: start the fast training run only when executed directly.
if __name__ == "__main__":
    main()
|
training/train_hubert_salr.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Train HuBERT-SALR model for dysarthria detection."""
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import mlflow
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
from training.dataset import DysarthriaDataset, collate_fn
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HuBERTSALRModel(nn.Module):
    """HuBERT encoder with layer-weighted pooling and two output heads.

    The per-layer hidden states of the HuBERT transformer are combined with
    learnable scalar weights, averaged over time, and fed to:

    * ``classifier`` - 2-way logits (binary classification), and
    * ``embedder``   - a 128-d, L2-normalised embedding intended for a
      triplet loss.

    Layer count and hidden width are read from the loaded checkpoint's config,
    so checkpoints other than the default large model also work. For the
    default checkpoint the parameter shapes are unchanged (24 layer weights,
    1024-wide inputs to both heads).
    """

    def __init__(self, hubert_checkpoint="facebook/hubert-large-ll60k"):
        """Load the pretrained encoder and build both heads.

        Args:
            hubert_checkpoint: Hugging Face model id or local path passed to
                ``HubertModel.from_pretrained``.
        """
        super().__init__()

        # Imported here so that importing this module does not itself
        # require ``transformers``.
        from transformers import HubertModel

        # Load pretrained HuBERT
        self.hubert = HubertModel.from_pretrained(hubert_checkpoint)

        # Freeze the feature extractor; the remaining encoder stays trainable.
        for param in self.hubert.feature_extractor.parameters():
            param.requires_grad = False

        num_layers = self.hubert.config.num_hidden_layers
        hidden_size = self.hubert.config.hidden_size

        # Layer-weighted pooling: one learnable scalar per transformer layer,
        # initialised to a uniform average.
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2),  # Binary classification
        )

        # Embedding head
        self.embedder = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),  # Embedding for triplet loss
        )

    def forward(self, waveform):
        """Encode ``waveform`` and return ``(logits, embeddings)``.

        Args:
            waveform: Raw audio batch passed straight to HuBERT
                (assumed shape (batch, samples) - TODO confirm against collate_fn).

        Returns:
            logits: (batch, 2) classification scores.
            embeddings: (batch, 128) embeddings with unit L2 norm along dim 1.
        """
        outputs = self.hubert(waveform, output_hidden_states=True)

        # ``hidden_states`` holds the embedding output followed by one tensor
        # per transformer layer (num_hidden_layers + 1 entries in total).
        # Taking the last ``num_layers`` entries pairs each weight with a
        # transformer layer and includes the final layer, which indexing
        # ``range(num_layers)`` from the front would drop.
        num_layers = self.layer_weights.shape[0]
        layer_outputs = torch.stack(outputs.hidden_states[-num_layers:], dim=0)

        # Weighted sum over the layer axis -> (batch, seq_len, hidden_size)
        weights = self.layer_weights.view(-1, 1, 1, 1)
        weighted_hidden = (weights * layer_outputs).sum(dim=0)

        # Global average pooling over time -> (batch, hidden_size).
        # NOTE(review): padded frames are included in the mean because no
        # attention mask is passed in - confirm this is acceptable.
        pooled = weighted_hidden.mean(dim=1)

        logits = self.classifier(pooled)

        embeddings = self.embedder(pooled)
        embeddings = nn.functional.normalize(embeddings, p=2, dim=1)

        return logits, embeddings
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def train_hubert_salr(
    train_manifest="data/manifests/train_manifest.csv",
    val_manifest="data/manifests/val_manifest.csv",
    batch_size=8,
    num_epochs=50,
    learning_rate=1e-4,
    device="cuda",
):
    """
    Train HuBERT-SALR model.

    Runs inside a single MLflow run (experiment ``dysarthria_hubert_salr``),
    logs per-epoch metrics, and saves the weights with the lowest validation
    loss to ``models/checkpoints/hubert_salr_best.pt``.

    Args:
        train_manifest: Path to training manifest
        val_manifest: Path to validation manifest
        batch_size: Batch size
        num_epochs: Number of epochs
        learning_rate: Learning rate
        device: Requested torch device (e.g. "cuda", "cpu", "mps"). A CUDA
            request falls back to CPU when CUDA is unavailable; any other
            device is used as given.
    """
    # Resolve the device. Only a CUDA request is downgraded, so that an
    # explicit non-CUDA choice such as "mps" is not silently replaced by CPU.
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU")
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # Initialize MLflow
    mlflow.set_experiment("dysarthria_hubert_salr")

    with mlflow.start_run():
        # Log parameters
        mlflow.log_params({
            "model": "HuBERT-SALR",
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "learning_rate": learning_rate,
        })

        # Create datasets (augmentation on the training split only)
        train_dataset = DysarthriaDataset(train_manifest, augment=True)
        val_dataset = DysarthriaDataset(val_manifest, augment=False)

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,  # Disabled for compatibility
            collate_fn=collate_fn,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,  # Disabled for compatibility
            collate_fn=collate_fn,
        )

        # Initialize model
        model = HuBERTSALRModel().to(device)

        # Optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        # Losses
        ce_loss_fn = nn.CrossEntropyLoss()
        # NOTE: not used yet - the triplet term below is a constant placeholder.
        triplet_loss_fn = nn.TripletMarginLoss(margin=1.0)

        # Training loop
        best_val_loss = float("inf")
        checkpoint_path = Path("models/checkpoints/hubert_salr_best.pt")

        for epoch in range(num_epochs):
            # Training
            model.train()
            train_loss = 0.0

            for batch in train_loader:
                waveform = batch["waveform"].to(device)
                # assumes collate_fn yields labels shaped (batch, 1) - TODO confirm
                labels = batch["label"].squeeze(1).to(device)

                optimizer.zero_grad()

                # Forward pass
                logits, embeddings = model(waveform)

                # Classification loss
                ce_loss = ce_loss_fn(logits, labels)

                # Triplet loss placeholder: constant zero, so the embeddings
                # currently contribute nothing to the objective.
                # In full implementation, use hard negative mining
                triplet_loss = torch.tensor(0.0).to(device)  # Placeholder

                # Combined loss
                loss = ce_loss + 0.5 * triplet_loss

                # Backward pass
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            train_loss /= len(train_loader)

            # Validation
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for batch in val_loader:
                    waveform = batch["waveform"].to(device)
                    labels = batch["label"].squeeze(1).to(device)

                    logits, _ = model(waveform)
                    loss = ce_loss_fn(logits, labels)

                    val_loss += loss.item()

                    preds = logits.argmax(dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

            val_loss /= len(val_loader)
            val_acc = correct / total

            # Log metrics
            mlflow.log_metrics({
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
            }, step=epoch)

            logger.info(
                f"Epoch {epoch+1}/{num_epochs}: "
                f"train_loss={train_loss:.4f}, "
                f"val_loss={val_loss:.4f}, "
                f"val_acc={val_acc:.4f}"
            )

            # Save best model (lowest validation loss so far)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)
                mlflow.log_artifact(str(checkpoint_path))

        logger.info("Training complete!")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Script entry point: train with the default arguments when executed directly.
if __name__ == "__main__":
    train_hubert_salr()
|