bbmb commited on
Commit
2cb85b0
Β·
1 Parent(s): 8475d85

Fix feature extraction and models to match notebook

Browse files
Files changed (2) hide show
  1. ml/feature_extraction.py +202 -144
  2. routers/training.py +446 -119
ml/feature_extraction.py CHANGED
@@ -1,162 +1,220 @@
1
  """
2
- Feature extraction β€” 82-dimensional feature vector per hit.
3
-
4
- Breakdown:
5
- [0:50] Welch PSD β€” 50 log-spaced frequency bins (relative, sum=1)
6
- [50:63] MFCC mean β€” 13 coefficients (mean-subtracted per coefficient)
7
- [63:76] MFCC std β€” 13 coefficients
8
- [76] Decay time constant Ο„ (exponential fit)
9
- [77] Energy ratio (late/early)
10
- [78] Crest factor (peak/RMS)
11
- [79] Log peak amplitude
12
- [80:82] Spectral centroid mean + std
13
  """
14
 
15
  import warnings
16
  import numpy as np
17
  import librosa
18
  from scipy.signal import welch
19
- from scipy.optimize import curve_fit
20
 
21
- from config import SR, N_MFCC, N_FFT, HIT_WINDOW_LEN, N_MELS, HOP_LENGTH_MEL, SPEC_TIME_FRAMES
22
 
23
- # ─── PSD ──────────────────────────────────────────────────────────────────────
 
 
 
 
 
24
 
25
- def extract_psd(window: np.ndarray, n_bins: int = 50) -> np.ndarray:
26
- """
27
- Welch PSD normalised to sum=1 (relative PSD), log-spaced frequency bins.
28
- Normalisation makes features robust to recording-level differences.
29
- """
30
- f, pxx = welch(window, fs=SR, nperseg=512, noverlap=256)
31
- # Log-spaced bin edges from 50 Hz to Nyquist
32
- edges = np.logspace(np.log10(50), np.log10(SR / 2), n_bins + 1)
33
- binned = np.zeros(n_bins, dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  for i in range(n_bins):
35
- mask = (f >= edges[i]) & (f < edges[i + 1])
36
- if mask.any():
37
- binned[i] = float(pxx[mask].mean())
38
- total = binned.sum()
39
- if total > 0:
40
- binned /= total
41
- return binned
42
-
43
-
44
- # ─── MFCC ─────────────────────────────────────────────────────────────────────
45
-
46
- def extract_mfcc(window: np.ndarray, n_mfcc: int = N_MFCC) -> tuple[np.ndarray, np.ndarray]:
47
- """
48
- Returns (mfcc_mean, mfcc_std) β€” both shape (n_mfcc,).
49
- Mean subtraction (per coefficient across the window) is applied first β€”
50
- this is the single most impactful normalisation for cross-session robustness.
51
- """
52
  with warnings.catch_warnings():
53
  warnings.simplefilter("ignore")
54
- mfcc = librosa.feature.mfcc(y=window, sr=SR, n_mfcc=n_mfcc, n_fft=N_FFT)
55
- # Per-coefficient mean subtraction
56
- mfcc -= mfcc.mean(axis=1, keepdims=True)
57
- return mfcc.mean(axis=1).astype(np.float32), mfcc.std(axis=1).astype(np.float32)
58
-
59
-
60
- # ─── Decay ────────────────────────────────────────────────────────────────────
61
-
62
- def _exp_model(t, A, tau):
63
- return A * np.exp(-t / tau)
64
-
65
-
66
- def extract_decay(window: np.ndarray) -> float:
67
- """
68
- Fit AΒ·exp(-t/Ο„) to the RMS energy envelope.
69
- Returns Ο„ (seconds). Larger Ο„ = slower decay = tighter flange.
70
- Falls back to 0.05 on fit failure.
71
- """
72
- # RMS in 5ms frames
73
- frame_len = int(0.005 * SR)
74
- n_frames = len(window) // frame_len
75
- rms_env = np.array([
76
- np.sqrt(np.mean(window[i * frame_len:(i + 1) * frame_len] ** 2))
77
- for i in range(n_frames)
78
- ], dtype=np.float32)
79
- t = np.arange(n_frames) * 0.005
80
- # Trim to non-zero region
81
- thresh = 0.02 * rms_env.max()
82
- valid = rms_env > thresh
83
- if valid.sum() < 5:
84
- return 0.05
85
- try:
86
- p0 = (rms_env[valid].max(), 0.15)
87
- popt, _ = curve_fit(_exp_model, t[valid], rms_env[valid], p0=p0,
88
- bounds=([0, 0.001], [np.inf, 2.0]), maxfev=2000)
89
- tau = float(np.clip(popt[1], 0.001, 2.0))
90
- except Exception:
91
- tau = 0.05
92
- return tau
93
-
94
-
95
- # ─── Energy ratio ─────────────────────────────────────────────────────────────
96
-
97
- def extract_energy_ratio(window: np.ndarray, split_ms: float = 50.0) -> float:
98
- """
99
- E_late / E_early where split is at split_ms ms after the hit onset.
100
- High ratio β†’ flange still ringing β†’ tight.
101
- """
102
- split_n = int(split_ms / 1000 * SR)
103
- early = float(np.sum(window[:split_n] ** 2)) + 1e-12
104
- late = float(np.sum(window[split_n:] ** 2))
105
- return min(late / early, 100.0) # cap for numerical safety
106
-
107
-
108
- # ─── Spectral centroid ────────────────────────────────────────────────────────
109
-
110
- def extract_spectral_centroid(window: np.ndarray) -> tuple[float, float]:
111
- cent = librosa.feature.spectral_centroid(y=window, sr=SR, n_fft=N_FFT)[0]
112
- return float(cent.mean()), float(cent.std())
113
-
114
-
115
- # ─── Full 82-dim vector ───────────────────────────────────────────────────────
116
-
117
- def extract_features(window: np.ndarray) -> np.ndarray:
118
- """Return 82-dim feature vector for a single hit window."""
119
- psd = extract_psd(window) # 50
120
- mfcc_m, mfcc_s = extract_mfcc(window) # 13 + 13
121
- tau = extract_decay(window) # 1
122
- energy_ratio = extract_energy_ratio(window) # 1
123
- peak_amp = float(np.abs(window).max())
124
- rms = float(np.sqrt(np.mean(window ** 2))) + 1e-12
125
- crest = float(peak_amp / rms) # 1
126
- log_peak = float(np.log1p(peak_amp)) # 1
127
- sc_mean, sc_std = extract_spectral_centroid(window) # 2
128
-
129
- feat = np.concatenate([
130
- psd,
131
- mfcc_m,
132
- mfcc_s,
133
- [tau, energy_ratio, crest, log_peak, sc_mean, sc_std],
134
  ]).astype(np.float32)
135
- return feat
136
-
137
 
138
- FEATURE_NAMES: list[str] = (
139
- [f"psd_{i}" for i in range(50)] +
140
- [f"mfcc_mean_{i}" for i in range(13)] +
141
- [f"mfcc_std_{i}" for i in range(13)] +
142
- ["tau", "energy_ratio", "crest_factor", "log_peak_amp", "sc_mean", "sc_std"]
143
- )
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- # ─── Mel spectrogram (CNN input) ─────────────────────────────────────────────
147
-
148
- def extract_mel_spectrogram(window: np.ndarray) -> np.ndarray:
149
- """
150
- Returns mel spectrogram of shape (N_MELS, SPEC_TIME_FRAMES) in dB.
151
- Used as CNN input (add channel dim in model).
152
- """
153
- mel = librosa.feature.melspectrogram(
154
- y=window, sr=SR, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH_MEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
- mel_db = librosa.power_to_db(mel, ref=np.max)
157
- # Pad / trim to fixed width
158
- if mel_db.shape[1] < SPEC_TIME_FRAMES:
159
- mel_db = np.pad(mel_db, ((0, 0), (0, SPEC_TIME_FRAMES - mel_db.shape[1])))
160
- else:
161
- mel_db = mel_db[:, :SPEC_TIME_FRAMES]
162
- return mel_db.astype(np.float32)
 
 
 
 
 
 
 
 
1
  """
2
+ Feature extraction β€” 82-dimensional physics-informed vector.
3
+ EXACTLY matches Cell 4 of final_project_saurav_silwal.ipynb.
4
+
5
+ Group 1: Relative PSD in 20 log-spaced bins (50 Hz – 8 kHz) β†’ 20 dims
6
+ Group 2: MFCC mean/std + delta MFCC mean/std (13 coeffs each) β†’ 52 dims
7
+ Group 3: Physics features (centroid, bandwidth, rolloff, ZCR,
8
+ peak freq, decay Ο„, energy ratio, RMS, Q-factor) β†’ 10 dims
9
+ Total β†’ 82 dims
 
 
 
10
  """
11
 
12
  import warnings
13
  import numpy as np
14
  import librosa
15
  from scipy.signal import welch
 
16
 
17
+ from config import SR, N_MELS, N_FFT, HOP_LENGTH_MEL, SPEC_TIME_FRAMES
18
 
19
+ # ── Constants matching notebook ────────────────────────────────────────────
20
+ N_PSD_BINS = 20
21
+ PSD_FMIN = 50.0
22
+ PSD_FMAX = 8000.0
23
+ WELCH_NPERSEG = 2048
24
+ WELCH_NOVERLAP = 1024
25
 
26
+ N_MFCC = 13
27
+ MFCC_NFFT = 2048
28
+ MFCC_HOP = 512
29
+
30
+ DECAY_FIT_MS = 200
31
+ EARLY_LATE_FRAC = 0.20
32
+ PEAK_SAMPLE_IN_WIN = int(0.020 * SR) # 960 samples = 20 ms pre-peak
33
+
34
+ FMIN_MEL = 0
35
+ FMAX_MEL = SR // 2 # Nyquist = 24 000 Hz
36
+
37
+
38
+ # ── Group 1: Relative PSD ─────────────────────────────────────────────────
39
+
40
+ def relative_psd_log_bins(y, sr=SR, n_bins=N_PSD_BINS,
41
+ f_min=PSD_FMIN, f_max=PSD_FMAX):
42
+ """Welch PSD β†’ 20 log-spaced bins β†’ normalized so sum=1."""
43
+ f, pxx = welch(y, fs=sr,
44
+ nperseg=min(WELCH_NPERSEG, len(y)),
45
+ noverlap=min(WELCH_NOVERLAP, len(y) // 2))
46
+ edges = np.logspace(np.log10(f_min), np.log10(f_max), n_bins + 1)
47
+ bins = np.zeros(n_bins, dtype=np.float32)
48
  for i in range(n_bins):
49
+ mask = (f >= edges[i]) & (f < edges[i + 1])
50
+ bins[i] = pxx[mask].sum()
51
+ total = bins.sum()
52
+ if total > 1e-20:
53
+ bins /= total
54
+ return bins, f, pxx
55
+
56
+
57
+ # ── Group 2: MFCC + delta statistics ──────────────────────────────────────
58
+
59
+ def mfcc_stats(y, sr=SR):
60
+ """13 MFCCs β†’ mean+std (26) + delta mean+std (26) = 52 dims."""
 
 
 
 
 
61
  with warnings.catch_warnings():
62
  warnings.simplefilter("ignore")
63
+ mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC,
64
+ n_fft=MFCC_NFFT, hop_length=MFCC_HOP)
65
+ delta = librosa.feature.delta(mfcc)
66
+ return np.concatenate([
67
+ mfcc.mean(axis=1), mfcc.std(axis=1),
68
+ delta.mean(axis=1), delta.std(axis=1),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  ]).astype(np.float32)
 
 
70
 
 
 
 
 
 
 
71
 
72
+ # ── Group 3: Physics features ──────────────────────────────────────────────
73
+
74
+ def peak_frequency(f, pxx):
75
+ if pxx.max() <= 0:
76
+ return 0.0
77
+ return float(f[np.argmax(pxx)])
78
+
79
+
80
+ def q_factor(f, pxx):
81
+ """Q = f_peak / -3 dB bandwidth. High Q = tight (rings cleanly)."""
82
+ if pxx.max() <= 0:
83
+ return 0.0
84
+ pdb = 10 * np.log10(pxx + 1e-20)
85
+ peak_idx = int(np.argmax(pdb))
86
+ threshold = pdb[peak_idx] - 3.0
87
+ L = peak_idx
88
+ while L > 0 and pdb[L] >= threshold:
89
+ L -= 1
90
+ R = peak_idx
91
+ while R < len(pdb) - 1 and pdb[R] >= threshold:
92
+ R += 1
93
+ bw = max(f[R] - f[L], 1.0)
94
+ return float(f[peak_idx] / bw)
95
+
96
+
97
+ def decay_tau(y, peak_sample=PEAK_SAMPLE_IN_WIN, sr=SR, fit_ms=DECAY_FIT_MS):
98
+ """Decay time constant Ο„. Loose β†’ small Ο„. Tight β†’ large Ο„."""
99
+ n_fit = int(fit_ms * sr / 1000)
100
+ seg = y[peak_sample:min(peak_sample + n_fit, len(y))]
101
+ if len(seg) < 100:
102
+ return np.nan
103
+ env_w = max(1, int(0.005 * sr))
104
+ env = np.convolve(np.abs(seg), np.ones(env_w) / env_w, mode='same')
105
+ if env.max() < 1e-8:
106
+ return np.nan
107
+ active = np.where(env > 0.05 * env.max())[0]
108
+ if len(active) < 50:
109
+ return np.nan
110
+ n_active = active[-1] + 1
111
+ eps = env.max() * 1e-4
112
+ log_env = np.log(env[:n_active] + eps)
113
+ t = np.arange(n_active) / sr
114
+ slope, _ = np.polyfit(t, log_env, 1)
115
+ if slope >= 0:
116
+ return np.nan
117
+ tau = -1.0 / slope
118
+ return float(tau) if 0.001 <= tau <= 10.0 else np.nan
119
+
120
+
121
+ def energy_ratio(y, frac=EARLY_LATE_FRAC):
122
+ """E_late / E_early. Tight flanges still ringing β†’ high ratio."""
123
+ n_chunk = int(frac * len(y))
124
+ e_early = np.sqrt(np.mean(y[:n_chunk] ** 2))
125
+ e_late = np.sqrt(np.mean(y[-n_chunk:] ** 2))
126
+ return float(e_late / (e_early + 1e-12))
127
+
128
+
129
+ # ── Master 82-dim extractor ───────────────────────────────────────────────
130
+
131
+ def extract_features(y: np.ndarray, sr: int = SR) -> np.ndarray:
132
+ """Return 82-dim feature vector for one hit window."""
133
+ psd_bins, f_psd, pxx = relative_psd_log_bins(y, sr) # 20
134
+
135
+ cepstral = mfcc_stats(y, sr) # 52
136
 
137
+ with warnings.catch_warnings():
138
+ warnings.simplefilter("ignore")
139
+ sc = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
140
+ sb = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
141
+ sr85 = librosa.feature.spectral_rolloff(y=y, sr=sr, roll_percent=0.85)[0]
142
+ zcr = librosa.feature.zero_crossing_rate(y)[0]
143
+
144
+ physics = np.array([
145
+ sc.mean(), sc.std(), # 2: centroid mean/std
146
+ sb.mean(), # 1: bandwidth mean
147
+ sr85.mean(), # 1: rolloff 85%
148
+ zcr.mean(), # 1: zero-crossing rate
149
+ peak_frequency(f_psd, pxx), # 1: dominant freq
150
+ decay_tau(y), # 1: Ο„ (NaN β†’ imputed after)
151
+ energy_ratio(y), # 1: E_late / E_early
152
+ float(np.sqrt(np.mean(y ** 2))),# 1: RMS energy
153
+ q_factor(f_psd, pxx), # 1: Q-factor
154
+ ], dtype=np.float32) # 10 total
155
+
156
+ return np.concatenate([psd_bins, cepstral, physics])
157
+
158
+
159
+ def impute_nans(X: np.ndarray, y_labels: np.ndarray, n_classes: int = 3) -> np.ndarray:
160
+ """Per-class median imputation for NaN columns (tau can be NaN)."""
161
+ X = X.copy()
162
+ nan_cols = np.where(np.isnan(X).any(axis=0))[0]
163
+ for c in nan_cols:
164
+ for cls in range(n_classes):
165
+ cls_mask = (y_labels == cls)
166
+ median_val = float(np.nanmedian(X[cls_mask, c]))
167
+ if np.isnan(median_val):
168
+ median_val = float(np.nanmedian(X[:, c]))
169
+ fill_mask = cls_mask & np.isnan(X[:, c])
170
+ X[fill_mask, c] = median_val
171
+ return X
172
+
173
+
174
+ # ── Feature name list ─────────────────────────────────────────────────────
175
+
176
+ def _build_feature_names() -> list[str]:
177
+ names = []
178
+ edges = np.logspace(np.log10(PSD_FMIN), np.log10(PSD_FMAX), N_PSD_BINS + 1)
179
+ for i in range(N_PSD_BINS):
180
+ names.append(f'psd_{edges[i]:.0f}_{edges[i+1]:.0f}Hz')
181
+ names += [f'mfcc{i:02d}_mean' for i in range(N_MFCC)]
182
+ names += [f'mfcc{i:02d}_std' for i in range(N_MFCC)]
183
+ names += [f'dmfcc{i:02d}_mean' for i in range(N_MFCC)]
184
+ names += [f'dmfcc{i:02d}_std' for i in range(N_MFCC)]
185
+ names += ['spec_centroid_mean', 'spec_centroid_std',
186
+ 'spec_bandwidth_mean', 'spec_rolloff85_mean',
187
+ 'zero_cross_rate_mean', 'peak_frequency',
188
+ 'decay_tau', 'energy_ratio', 'rms_energy', 'q_factor']
189
+ return names
190
+
191
+ FEATURE_NAMES: list[str] = _build_feature_names()
192
+ assert len(FEATURE_NAMES) == 82
193
+
194
+
195
+ # ── Mel spectrogram (CNN / BiLSTM input) β€” matches notebook Cell 5 ────────
196
+
197
+ def extract_mel_spectrogram(y: np.ndarray, sr: int = SR,
198
+ n_mels: int = N_MELS,
199
+ n_fft: int = N_FFT,
200
+ hop_length: int = HOP_LENGTH_MEL,
201
+ target_frames: int = SPEC_TIME_FRAMES) -> np.ndarray:
202
+ """One hit β†’ standardized log-mel spectrogram of shape (n_mels, target_frames)."""
203
+ mel = librosa.feature.melspectrogram(
204
+ y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
205
+ fmin=FMIN_MEL, fmax=FMAX_MEL, power=2.0,
206
  )
207
+ mel_db = librosa.power_to_db(mel, ref=np.max).astype(np.float32)
208
+
209
+ n_frames = mel_db.shape[1]
210
+ if n_frames < target_frames:
211
+ pad_val = float(mel_db.min())
212
+ mel_db = np.pad(mel_db, ((0, 0), (0, target_frames - n_frames)),
213
+ mode='constant', constant_values=pad_val)
214
+ elif n_frames > target_frames:
215
+ start = (n_frames - target_frames) // 2
216
+ mel_db = mel_db[:, start:start + target_frames]
217
+
218
+ # per-sample standardize
219
+ mu, sigma = mel_db.mean(), mel_db.std()
220
+ return (mel_db - mu) / (sigma + 1e-6) # shape (64, 128)
routers/training.py CHANGED
@@ -1,15 +1,18 @@
1
  """
2
  Router: POST /api/train β€” launch training as a background task
3
  GET /api/results β€” fetch completed results
4
- WS /ws/train/{task_id} β€” stream live epoch metrics
5
 
6
- Models trained: SVM, LR, RF, MLP, KNN (shallow) + CNN, LSTM (deep via Keras).
7
- Each model runs LOIO cross-validation (Task 2) + 70/30 split (Task 1).
 
 
 
 
8
  """
9
 
10
  import asyncio
11
  import uuid
12
- import threading
13
  import traceback
14
  from concurrent.futures import ThreadPoolExecutor
15
 
@@ -19,123 +22,225 @@ from sklearn.preprocessing import StandardScaler
19
  from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
20
  from sklearn.svm import SVC
21
  from sklearn.linear_model import LogisticRegression
22
- from sklearn.ensemble import RandomForestClassifier
23
- from sklearn.neural_network import MLPClassifier
24
  from sklearn.neighbors import KNeighborsClassifier
25
  from fastapi import APIRouter, Header, HTTPException, WebSocket, WebSocketDisconnect
26
 
27
  from session import session_manager
28
  from ws.training_ws import ws_manager, emit
 
 
 
29
  from config import SEED, TEST_SIZE, IDX_TO_CLASS, CLASS_NAMES, N_CLASSES
30
 
31
- router = APIRouter(tags=["training"])
32
-
33
- # Shared thread pool for background training
34
  _executor = ThreadPoolExecutor(max_workers=2)
35
 
36
 
37
- # ─── Model definitions ────────────────────────────────────────────────────────
 
 
38
 
39
- SHALLOW_MODELS = {
40
  "SVM": lambda: SVC(
41
  kernel="rbf", C=10.0, gamma="scale",
42
- probability=True, class_weight="balanced", random_state=SEED
43
  ),
44
  "LR": lambda: LogisticRegression(
45
  C=1.0, max_iter=2000, class_weight="balanced",
46
- multi_class="multinomial", solver="lbfgs", random_state=SEED
47
- ),
48
- "RF": lambda: RandomForestClassifier(
49
- n_estimators=200, max_depth=None,
50
- class_weight="balanced", random_state=SEED, n_jobs=-1
51
  ),
52
- "MLP": lambda: MLPClassifier(
53
- hidden_layer_sizes=(128, 64), activation="relu",
54
- max_iter=500, early_stopping=True, random_state=SEED
55
  ),
56
- "KNN": lambda: KNeighborsClassifier(n_neighbors=5, metric="euclidean"),
57
  }
58
 
59
 
60
- # ─── Training worker (runs in thread) ─────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def _train_shallow(
63
- task_id: str,
64
- model_name: str,
65
- X: np.ndarray,
66
- y: np.ndarray,
67
- groups: np.ndarray,
68
- loop: asyncio.AbstractEventLoop,
69
- queue: asyncio.Queue,
70
- session,
71
- ):
 
 
 
 
72
  try:
73
  scaler = StandardScaler()
74
- X_s = scaler.fit_transform(X)
75
 
76
- # ── Task 1: Dependent 70/30 split ──
77
  X_tr, X_te, y_tr, y_te = train_test_split(
78
  X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
79
  )
80
- clf_t1 = SHALLOW_MODELS[model_name]()
81
- clf_t1.fit(X_tr, y_tr)
82
- y_pred_t1 = clf_t1.predict(X_te)
83
- acc_t1 = float(accuracy_score(y_te, y_pred_t1))
84
- f1_t1 = float(f1_score(y_te, y_pred_t1, average="macro"))
85
- cm_t1 = confusion_matrix(y_te, y_pred_t1, labels=[0, 1, 2]).tolist()
86
-
87
- emit(loop, queue, {
88
- "type": "task1_done", "model": model_name,
89
- "acc": round(acc_t1, 4), "f1": round(f1_t1, 4),
90
- })
91
-
92
- # ── Task 2: LOIO cross-validation ──
93
- logo = LeaveOneGroupOut()
94
- fold_accs: list[float] = []
95
- fold_records: list[dict] = []
96
 
97
  for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
98
  flange_out = int(groups[te_idx[0]])
99
- clf = SHALLOW_MODELS[model_name]()
100
- clf.fit(X_s[tr_idx], y[tr_idx])
101
- y_p = clf.predict(X_s[te_idx])
102
- y_pr = clf.predict_proba(X_s[te_idx]) if hasattr(clf, "predict_proba") else None
103
- acc_f = float(accuracy_score(y[te_idx], y_p))
104
  fold_accs.append(acc_f)
105
- fold_records.append({
106
- "fold": fold_i + 1,
107
- "flange_out": flange_out,
108
- "acc": round(acc_f, 4),
109
- "n_test": len(te_idx),
110
- })
111
- emit(loop, queue, {
112
- "type": "fold_done",
113
- "model": model_name,
114
- "fold": fold_i + 1,
115
- "flange_out": flange_out,
116
- "acc": round(acc_f, 4),
117
- })
118
-
119
- # Final model on all data (for ensemble / CORAL)
120
- clf_final = SHALLOW_MODELS[model_name]()
121
  clf_final.fit(X_s, y)
122
- y_pred_all = clf_final.predict(X_s)
123
- train_acc = float(accuracy_score(y, y_pred_all))
124
- cm_loio_pooled = confusion_matrix(
125
- [f["flange_out"] for f in fold_records], # dummy β€” use actual pooled
126
- [f["flange_out"] for f in fold_records],
127
- ).tolist()
128
-
129
- # Pooled LOIO confusion matrix
130
- all_y_true: list[int] = []
131
- all_y_pred: list[int] = []
132
- for tr_idx, te_idx in logo.split(X_s, y, groups):
133
- clf = SHALLOW_MODELS[model_name]()
134
- clf.fit(X_s[tr_idx], y[tr_idx])
135
- all_y_true.extend(y[te_idx].tolist())
136
- all_y_pred.extend(clf.predict(X_s[te_idx]).tolist())
137
- cm_t2 = confusion_matrix(all_y_true, all_y_pred, labels=[0, 1, 2]).tolist()
138
- f1_t2 = float(f1_score(all_y_true, all_y_pred, average="macro"))
139
 
140
  result = {
141
  "model": model_name,
@@ -143,17 +248,90 @@ def _train_shallow(
143
  "task1_f1": round(f1_t1, 4),
144
  "task1_cm": cm_t1,
145
  "task2_mean": round(float(np.mean(fold_accs)), 4),
146
- "task2_std": round(float(np.std(fold_accs)), 4),
147
  "task2_f1": round(f1_t2, 4),
148
  "task2_cm": cm_t2,
149
  "folds": fold_records,
150
  "train_acc": round(train_acc, 4),
151
- "scaler_mean": scaler.mean_.tolist(),
152
- "scaler_scale": scaler.scale_.tolist(),
153
  }
154
  session.training_results[model_name] = result
155
  session.touch()
 
 
 
 
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  emit(loop, queue, {"type": "model_done", "model": model_name, **result})
158
 
159
  except Exception as e:
@@ -161,15 +339,167 @@ def _train_shallow(
161
  traceback.print_exc()
162
 
163
 
164
- def _train_all_models(task_id: str, session_id: str, models: list[str]):
165
- """Entry point for background thread: trains all requested models sequentially."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  loop = asyncio.new_event_loop()
167
  asyncio.set_event_loop(loop)
168
 
169
  session = session_manager.get(session_id)
170
  if session is None:
171
  return
172
-
173
  queue = ws_manager.get_queue(task_id)
174
  if queue is None:
175
  return
@@ -179,57 +509,57 @@ def _train_all_models(task_id: str, session_id: str, models: list[str]):
179
  emit(loop, queue, {"type": "error", "message": "Features not extracted yet"})
180
  return
181
 
182
- X = np.array(feats["X_feat"], dtype=np.float32)
183
- y = np.array(feats["labels"], dtype=np.int64)
184
- groups = np.array(feats["flange_groups"], dtype=np.int64)
185
 
186
- for model_name in models:
187
- if model_name in SHALLOW_MODELS:
188
- _train_shallow(task_id, model_name, X, y, groups, loop, queue, session)
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  emit(loop, queue, {"type": "all_done", "task_id": task_id})
191
 
192
 
193
- # ─── Routes ──────────────────────────────────────────────────────────────────
 
 
194
 
195
  @router.post("/api/train")
196
  async def start_training(
197
  session_id: str = Header(..., alias="X-Session-Id"),
198
  body: dict = None,
199
  ):
200
- """
201
- Launch background training. Returns task_id for WebSocket connection.
202
- Body: {"models": ["SVM", "LR", "RF", "MLP", "KNN"]}
203
- """
204
  session = session_manager.get(session_id)
205
  if session is None:
206
  raise HTTPException(status_code=404, detail="Session not found")
207
  if not session.features:
208
  raise HTTPException(status_code=400, detail="Extract features first: POST /api/features")
209
 
210
- models = (body or {}).get("models", list(SHALLOW_MODELS.keys()))
211
-
212
  task_id = str(uuid.uuid4())
213
- # Create queue before starting thread (thread will use it immediately)
214
  ws_manager.create_queue(task_id)
215
  session.training_tasks[task_id] = models
216
  session.touch()
217
 
218
  loop = asyncio.get_event_loop()
219
- loop.run_in_executor(
220
- _executor,
221
- _train_all_models,
222
- task_id,
223
- session_id,
224
- models,
225
- )
226
 
227
  return {"task_id": task_id, "models": models}
228
 
229
 
230
  @router.get("/api/results")
231
  async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
232
- """Return all completed training results for this session."""
233
  session = session_manager.get(session_id)
234
  if session is None:
235
  raise HTTPException(status_code=404, detail="Session not found")
@@ -239,11 +569,8 @@ async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
239
  }
240
 
241
 
242
- # ─── WebSocket endpoint ───────────────────────────────────────────────────────
243
-
244
  @router.websocket("/ws/train/{task_id}")
245
  async def training_websocket(websocket: WebSocket, task_id: str):
246
- """Stream live training events for a given task_id."""
247
  await ws_manager.connect(task_id, websocket)
248
  try:
249
  await ws_manager.stream(task_id, websocket)
 
1
  """
2
  Router: POST /api/train β€” launch training as a background task
3
  GET /api/results β€” fetch completed results
4
+ WS /ws/train/{task_id} β€” stream live epoch/fold metrics
5
 
6
+ Models (matching final_project_saurav_silwal.ipynb exactly):
7
+ Shallow (82-dim tabular features): SVM, LR, KNN
8
+ Deep (82-dim tabular): MLP (Keras, 3 hidden layers + dropout)
9
+ Deep (mel spectrogram): CNN, BiLSTM (Keras)
10
+
11
+ All models: Task 1 (70/30 dependent split) + Task 2 (LOIO cross-validation)
12
  """
13
 
14
  import asyncio
15
  import uuid
 
16
  import traceback
17
  from concurrent.futures import ThreadPoolExecutor
18
 
 
22
  from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
23
  from sklearn.svm import SVC
24
  from sklearn.linear_model import LogisticRegression
 
 
25
  from sklearn.neighbors import KNeighborsClassifier
26
  from fastapi import APIRouter, Header, HTTPException, WebSocket, WebSocketDisconnect
27
 
28
  from session import session_manager
29
  from ws.training_ws import ws_manager, emit
30
+ from ml.feature_extraction import (
31
+ extract_mel_spectrogram, impute_nans, FEATURE_NAMES
32
+ )
33
  from config import SEED, TEST_SIZE, IDX_TO_CLASS, CLASS_NAMES, N_CLASSES
34
 
35
+ router = APIRouter(tags=["training"])
 
 
36
  _executor = ThreadPoolExecutor(max_workers=2)
37
 
38
 
39
+ # ─────────────────────────────────────────────────────────────────────────────
40
+ # Shallow model factories (sklearn)
41
+ # ─────────────────────────────────────────────────────────────────────────────
42
 
43
+ SHALLOW_FACTORIES = {
44
  "SVM": lambda: SVC(
45
  kernel="rbf", C=10.0, gamma="scale",
46
+ probability=True, class_weight="balanced", random_state=SEED,
47
  ),
48
  "LR": lambda: LogisticRegression(
49
  C=1.0, max_iter=2000, class_weight="balanced",
50
+ multi_class="multinomial", solver="lbfgs", random_state=SEED,
 
 
 
 
51
  ),
52
+ "KNN": lambda: KNeighborsClassifier(
53
+ n_neighbors=5, metric="euclidean", weights="uniform",
 
54
  ),
 
55
  }
56
 
57
 
58
+ # ─────────────────────────────────────────────────────────────────────────────
59
+ # Keras model builders
60
+ # ─────────────────────────────────────────────────────────────────────────────
61
+
62
+ def build_mlp(input_dim: int, n_classes: int = 3):
63
+ """3-layer MLP with BatchNorm + Dropout. Matches notebook Cell 9."""
64
+ import tensorflow as tf
65
+ from tensorflow import keras
66
+
67
+ model = keras.Sequential([
68
+ keras.layers.Input(shape=(input_dim,)),
69
+ keras.layers.Dense(256),
70
+ keras.layers.BatchNormalization(),
71
+ keras.layers.Activation("relu"),
72
+ keras.layers.Dropout(0.4),
73
+
74
+ keras.layers.Dense(128),
75
+ keras.layers.BatchNormalization(),
76
+ keras.layers.Activation("relu"),
77
+ keras.layers.Dropout(0.3),
78
+
79
+ keras.layers.Dense(64),
80
+ keras.layers.BatchNormalization(),
81
+ keras.layers.Activation("relu"),
82
+ keras.layers.Dropout(0.2),
83
+
84
+ keras.layers.Dense(n_classes, activation="softmax"),
85
+ ])
86
+ model.compile(
87
+ optimizer=keras.optimizers.Adam(learning_rate=1e-3),
88
+ loss="sparse_categorical_crossentropy",
89
+ metrics=["accuracy"],
90
+ )
91
+ return model
92
+
93
+
94
+ def build_cnn(n_mels: int = 64, n_frames: int = 128, n_classes: int = 3):
95
+ """CNN on log-mel spectrogram (64Γ—128Γ—1). Matches notebook Cell 10."""
96
+ import tensorflow as tf
97
+ from tensorflow import keras
98
+
99
+ model = keras.Sequential([
100
+ keras.layers.Input(shape=(n_mels, n_frames, 1)),
101
+
102
+ keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
103
+ keras.layers.BatchNormalization(),
104
+ keras.layers.MaxPooling2D((2, 2)),
105
+ keras.layers.Dropout(0.25),
106
+
107
+ keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu"),
108
+ keras.layers.BatchNormalization(),
109
+ keras.layers.MaxPooling2D((2, 2)),
110
+ keras.layers.Dropout(0.25),
111
+
112
+ keras.layers.Conv2D(128, (3, 3), padding="same", activation="relu"),
113
+ keras.layers.BatchNormalization(),
114
+ keras.layers.GlobalAveragePooling2D(),
115
+ keras.layers.Dropout(0.4),
116
+
117
+ keras.layers.Dense(128, activation="relu"),
118
+ keras.layers.Dropout(0.3),
119
+ keras.layers.Dense(n_classes, activation="softmax"),
120
+ ])
121
+ model.compile(
122
+ optimizer="adam",
123
+ loss="sparse_categorical_crossentropy",
124
+ metrics=["accuracy"],
125
+ )
126
+ return model
127
+
128
+
129
+ def build_bilstm(n_frames: int = 128, n_mels: int = 64, n_classes: int = 3):
130
+ """Bidirectional LSTM on mel sequences (128 time steps Γ— 64 mel features).
131
+ Matches notebook Cell 11."""
132
+ import tensorflow as tf
133
+ from tensorflow import keras
134
+
135
+ model = keras.Sequential([
136
+ keras.layers.Input(shape=(n_frames, n_mels)),
137
+ keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True)),
138
+ keras.layers.Dropout(0.3),
139
+ keras.layers.Bidirectional(keras.layers.LSTM(32)),
140
+ keras.layers.Dropout(0.3),
141
+ keras.layers.Dense(64, activation="relu"),
142
+ keras.layers.Dropout(0.2),
143
+ keras.layers.Dense(n_classes, activation="softmax"),
144
+ ])
145
+ model.compile(
146
+ optimizer="adam",
147
+ loss="sparse_categorical_crossentropy",
148
+ metrics=["accuracy"],
149
+ )
150
+ return model
151
+
152
+
153
+ # ─────────────────────────────────────────────────────────────────────────────
154
+ # WebSocket epoch callback for Keras
155
+ # ─────────────────────────────────────────────────────────────────────────────
156
+
157
+ class WSCallback:
158
+ """Keras callback that emits epoch metrics over WebSocket."""
159
+
160
+ def __init__(self, loop, queue, model_name, total_epochs):
161
+ self.loop = loop
162
+ self.queue = queue
163
+ self.model_name = model_name
164
+ self.total_epochs = total_epochs
165
+
166
+ def on_epoch_end(self, epoch, logs=None):
167
+ logs = logs or {}
168
+ emit(self.loop, self.queue, {
169
+ "type": "epoch",
170
+ "model": self.model_name,
171
+ "epoch": epoch + 1,
172
+ "total": self.total_epochs,
173
+ "train_acc": round(float(logs.get("accuracy", 0)), 4),
174
+ "val_acc": round(float(logs.get("val_accuracy", 0)), 4),
175
+ "train_loss": round(float(logs.get("loss", 0)), 4),
176
+ "val_loss": round(float(logs.get("val_loss", 0)), 4),
177
+ })
178
+
179
+
180
+ def _make_keras_callback(loop, queue, model_name, total_epochs):
181
+ """Return a tf.keras.callbacks.Callback subclass instance."""
182
+ import tensorflow as tf
183
 
184
+ cb = WSCallback(loop, queue, model_name, total_epochs)
185
+
186
+ class _CB(tf.keras.callbacks.Callback):
187
+ def on_epoch_end(self, epoch, logs=None):
188
+ cb.on_epoch_end(epoch, logs)
189
+
190
+ return _CB()
191
+
192
+
193
+ # ─────────────────────────────────────────────────────────────────────────────
194
+ # Shallow model training (SVM / LR / KNN)
195
+ # ─────────────────────────────────────────────────────────────────────────────
196
+
197
+ def _train_shallow(model_name, X, y, groups, loop, queue, session):
198
  try:
199
  scaler = StandardScaler()
200
+ X_s = scaler.fit_transform(X)
201
 
202
+ # Task 1
203
  X_tr, X_te, y_tr, y_te = train_test_split(
204
  X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
205
  )
206
+ clf = SHALLOW_FACTORIES[model_name]()
207
+ clf.fit(X_tr, y_tr)
208
+ y_p1 = clf.predict(X_te)
209
+ acc_t1 = float(accuracy_score(y_te, y_p1))
210
+ f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
211
+ cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
212
+
213
+ emit(loop, queue, {"type": "task1_done", "model": model_name,
214
+ "acc": round(acc_t1, 4), "f1": round(f1_t1, 4)})
215
+
216
+ # Task 2 β€” LOIO
217
+ logo = LeaveOneGroupOut()
218
+ fold_accs = []
219
+ fold_records = []
220
+ all_yt, all_yp = [], []
 
221
 
222
  for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
223
  flange_out = int(groups[te_idx[0]])
224
+ clf2 = SHALLOW_FACTORIES[model_name]()
225
+ clf2.fit(X_s[tr_idx], y[tr_idx])
226
+ yp = clf2.predict(X_s[te_idx])
227
+ acc_f = float(accuracy_score(y[te_idx], yp))
 
228
  fold_accs.append(acc_f)
229
+ fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
230
+ "acc": round(acc_f, 4), "n_test": len(te_idx)})
231
+ all_yt.extend(y[te_idx].tolist())
232
+ all_yp.extend(yp.tolist())
233
+ emit(loop, queue, {"type": "fold_done", "model": model_name,
234
+ "fold": fold_i + 1, "flange_out": flange_out,
235
+ "acc": round(acc_f, 4)})
236
+
237
+ cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
238
+ f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
239
+
240
+ # Final model on all data
241
+ clf_final = SHALLOW_FACTORIES[model_name]()
 
 
 
242
  clf_final.fit(X_s, y)
243
+ train_acc = float(accuracy_score(y, clf_final.predict(X_s)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  result = {
246
  "model": model_name,
 
248
  "task1_f1": round(f1_t1, 4),
249
  "task1_cm": cm_t1,
250
  "task2_mean": round(float(np.mean(fold_accs)), 4),
251
+ "task2_std": round(float(np.std(fold_accs)), 4),
252
  "task2_f1": round(f1_t2, 4),
253
  "task2_cm": cm_t2,
254
  "folds": fold_records,
255
  "train_acc": round(train_acc, 4),
 
 
256
  }
257
  session.training_results[model_name] = result
258
  session.touch()
259
+ emit(loop, queue, {"type": "model_done", "model": model_name, **result})
260
+
261
+ except Exception as e:
262
+ emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
263
+ traceback.print_exc()
264
 
265
+
266
+ # ─────────────────────────────────────────────────────────────────────────────
267
+ # MLP training (Keras, tabular features)
268
+ # ─────────────────────────────────────────────────────────────────────────────
269
+
270
+ def _train_mlp(X, y, groups, loop, queue, session, epochs=60):
271
+ model_name = "MLP"
272
+ try:
273
+ import tensorflow as tf
274
+ tf.random.set_seed(SEED)
275
+
276
+ scaler = StandardScaler()
277
+ X_s = scaler.fit_transform(X)
278
+
279
+ # Task 1
280
+ X_tr, X_te, y_tr, y_te = train_test_split(
281
+ X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
282
+ )
283
+ model = build_mlp(X_s.shape[1])
284
+ cb = _make_keras_callback(loop, queue, model_name, epochs)
285
+ es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
286
+ model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
287
+ validation_split=0.15,
288
+ callbacks=[cb, es], verbose=0)
289
+
290
+ y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
291
+ acc_t1 = float(accuracy_score(y_te, y_p1))
292
+ f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
293
+ cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
294
+
295
+ emit(loop, queue, {"type": "task1_done", "model": model_name,
296
+ "acc": round(acc_t1, 4), "f1": round(f1_t1, 4)})
297
+
298
+ # Task 2 β€” LOIO
299
+ logo = LeaveOneGroupOut()
300
+ fold_accs = []
301
+ fold_records = []
302
+ all_yt, all_yp = [], []
303
+
304
+ for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
305
+ flange_out = int(groups[te_idx[0]])
306
+ m2 = build_mlp(X_s.shape[1])
307
+ es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
308
+ m2.fit(X_s[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
309
+ validation_split=0.15, callbacks=[es2], verbose=0)
310
+ yp = np.argmax(m2.predict(X_s[te_idx], verbose=0), axis=1)
311
+ acc_f = float(accuracy_score(y[te_idx], yp))
312
+ fold_accs.append(acc_f)
313
+ fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
314
+ "acc": round(acc_f, 4), "n_test": len(te_idx)})
315
+ all_yt.extend(y[te_idx].tolist())
316
+ all_yp.extend(yp.tolist())
317
+ emit(loop, queue, {"type": "fold_done", "model": model_name,
318
+ "fold": fold_i + 1, "flange_out": flange_out,
319
+ "acc": round(acc_f, 4)})
320
+
321
+ cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
322
+ f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
323
+ train_acc = float(accuracy_score(y, np.argmax(model.predict(X_s, verbose=0), axis=1)))
324
+
325
+ result = {
326
+ "model": model_name, "task1_acc": round(acc_t1, 4),
327
+ "task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
328
+ "task2_mean": round(float(np.mean(fold_accs)), 4),
329
+ "task2_std": round(float(np.std(fold_accs)), 4),
330
+ "task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
331
+ "folds": fold_records, "train_acc": round(train_acc, 4),
332
+ }
333
+ session.training_results[model_name] = result
334
+ session.touch()
335
  emit(loop, queue, {"type": "model_done", "model": model_name, **result})
336
 
337
  except Exception as e:
 
339
  traceback.print_exc()
340
 
341
 
342
+ # ─────────────────────────────────────────────────────────────────────────────
343
+ # CNN training (Keras, mel spectrograms)
344
+ # ─────────────────────────────────────────────────────────────────────────────
345
+
346
+ def _train_cnn(waveforms, y, groups, loop, queue, session, epochs=50):
347
+ model_name = "CNN"
348
+ try:
349
+ import tensorflow as tf
350
+ tf.random.set_seed(SEED)
351
+
352
+ # Build spectrogram tensor (N, 64, 128, 1)
353
+ emit(loop, queue, {"type": "task1_done", "model": model_name,
354
+ "acc": 0.0, "f1": 0.0, "message": "Extracting spectrograms..."})
355
+ X_spec = np.stack([
356
+ extract_mel_spectrogram(np.array(w, dtype=np.float32))
357
+ for w in waveforms
358
+ ], axis=0)[..., np.newaxis] # (N, 64, 128, 1)
359
+
360
+ # Task 1
361
+ X_tr, X_te, y_tr, y_te = train_test_split(
362
+ X_spec, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
363
+ )
364
+ model = build_cnn()
365
+ cb = _make_keras_callback(loop, queue, model_name, epochs)
366
+ es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
367
+ model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
368
+ validation_split=0.15, callbacks=[cb, es], verbose=0)
369
+
370
+ y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
371
+ acc_t1 = float(accuracy_score(y_te, y_p1))
372
+ f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
373
+ cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
374
+
375
+ # Task 2 β€” LOIO
376
+ logo = LeaveOneGroupOut()
377
+ fold_accs, fold_records, all_yt, all_yp = [], [], [], []
378
+
379
+ for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_spec, y, groups)):
380
+ flange_out = int(groups[te_idx[0]])
381
+ m2 = build_cnn()
382
+ es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
383
+ m2.fit(X_spec[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
384
+ validation_split=0.15, callbacks=[es2], verbose=0)
385
+ yp = np.argmax(m2.predict(X_spec[te_idx], verbose=0), axis=1)
386
+ acc_f = float(accuracy_score(y[te_idx], yp))
387
+ fold_accs.append(acc_f)
388
+ fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
389
+ "acc": round(acc_f, 4), "n_test": len(te_idx)})
390
+ all_yt.extend(y[te_idx].tolist())
391
+ all_yp.extend(yp.tolist())
392
+ emit(loop, queue, {"type": "fold_done", "model": model_name,
393
+ "fold": fold_i + 1, "flange_out": flange_out,
394
+ "acc": round(acc_f, 4)})
395
+
396
+ cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
397
+ f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
398
+ train_acc = float(accuracy_score(y, np.argmax(model.predict(X_spec, verbose=0), axis=1)))
399
+
400
+ result = {
401
+ "model": model_name, "task1_acc": round(acc_t1, 4),
402
+ "task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
403
+ "task2_mean": round(float(np.mean(fold_accs)), 4),
404
+ "task2_std": round(float(np.std(fold_accs)), 4),
405
+ "task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
406
+ "folds": fold_records, "train_acc": round(train_acc, 4),
407
+ }
408
+ session.training_results[model_name] = result
409
+ session.touch()
410
+ emit(loop, queue, {"type": "model_done", "model": model_name, **result})
411
+
412
+ except Exception as e:
413
+ emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
414
+ traceback.print_exc()
415
+
416
+
417
+ # ─────────────────────────────────────────────────────────────────────────────
418
+ # BiLSTM training (Keras, mel sequences)
419
+ # ─────────────────────────────────────────────────────────────────────────────
420
+
421
+ def _train_bilstm(waveforms, y, groups, loop, queue, session, epochs=50):
422
+ model_name = "BiLSTM"
423
+ try:
424
+ import tensorflow as tf
425
+ tf.random.set_seed(SEED)
426
+
427
+ # Reshape to (N, 128, 64) β€” time steps Γ— mel features
428
+ X_seq = np.stack([
429
+ extract_mel_spectrogram(np.array(w, dtype=np.float32)).T # (128, 64)
430
+ for w in waveforms
431
+ ], axis=0)
432
+
433
+ # Task 1
434
+ X_tr, X_te, y_tr, y_te = train_test_split(
435
+ X_seq, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
436
+ )
437
+ model = build_bilstm()
438
+ cb = _make_keras_callback(loop, queue, model_name, epochs)
439
+ es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
440
+ model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
441
+ validation_split=0.15, callbacks=[cb, es], verbose=0)
442
+
443
+ y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
444
+ acc_t1 = float(accuracy_score(y_te, y_p1))
445
+ f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
446
+ cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
447
+
448
+ # Task 2 β€” LOIO
449
+ logo = LeaveOneGroupOut()
450
+ fold_accs, fold_records, all_yt, all_yp = [], [], [], []
451
+
452
+ for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_seq, y, groups)):
453
+ flange_out = int(groups[te_idx[0]])
454
+ m2 = build_bilstm()
455
+ es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
456
+ m2.fit(X_seq[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
457
+ validation_split=0.15, callbacks=[es2], verbose=0)
458
+ yp = np.argmax(m2.predict(X_seq[te_idx], verbose=0), axis=1)
459
+ acc_f = float(accuracy_score(y[te_idx], yp))
460
+ fold_accs.append(acc_f)
461
+ fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
462
+ "acc": round(acc_f, 4), "n_test": len(te_idx)})
463
+ all_yt.extend(y[te_idx].tolist())
464
+ all_yp.extend(yp.tolist())
465
+ emit(loop, queue, {"type": "fold_done", "model": model_name,
466
+ "fold": fold_i + 1, "flange_out": flange_out,
467
+ "acc": round(acc_f, 4)})
468
+
469
+ cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
470
+ f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
471
+ train_acc = float(accuracy_score(y, np.argmax(model.predict(X_seq, verbose=0), axis=1)))
472
+
473
+ result = {
474
+ "model": model_name, "task1_acc": round(acc_t1, 4),
475
+ "task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
476
+ "task2_mean": round(float(np.mean(fold_accs)), 4),
477
+ "task2_std": round(float(np.std(fold_accs)), 4),
478
+ "task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
479
+ "folds": fold_records, "train_acc": round(train_acc, 4),
480
+ }
481
+ session.training_results[model_name] = result
482
+ session.touch()
483
+ emit(loop, queue, {"type": "model_done", "model": model_name, **result})
484
+
485
+ except Exception as e:
486
+ emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
487
+ traceback.print_exc()
488
+
489
+
490
+ # ─────────────────────────────────────────────────────────────────────────────
491
+ # Master training thread
492
+ # ─────────────────────────────────────────────────────────────────────────────
493
+
494
+ ALL_MODELS = ["SVM", "LR", "KNN", "MLP", "CNN", "BiLSTM"]
495
+
496
+ def _train_all(task_id: str, session_id: str, models: list[str]):
497
  loop = asyncio.new_event_loop()
498
  asyncio.set_event_loop(loop)
499
 
500
  session = session_manager.get(session_id)
501
  if session is None:
502
  return
 
503
  queue = ws_manager.get_queue(task_id)
504
  if queue is None:
505
  return
 
509
  emit(loop, queue, {"type": "error", "message": "Features not extracted yet"})
510
  return
511
 
512
+ X = np.array(feats["X_feat"], dtype=np.float32)
513
+ y = np.array(feats["labels"], dtype=np.int64)
514
+ groups = np.array(feats["flange_groups"], dtype=np.int64)
515
 
516
+ # NaN imputation (tau column can have NaNs)
517
+ X = impute_nans(X, y)
518
+
519
+ waveforms = session.hits.get("waveforms", [])
520
+
521
+ for m in models:
522
+ if m in SHALLOW_FACTORIES:
523
+ _train_shallow(m, X, y, groups, loop, queue, session)
524
+ elif m == "MLP":
525
+ _train_mlp(X, y, groups, loop, queue, session)
526
+ elif m == "CNN":
527
+ _train_cnn(waveforms, y, groups, loop, queue, session)
528
+ elif m == "BiLSTM":
529
+ _train_bilstm(waveforms, y, groups, loop, queue, session)
530
 
531
  emit(loop, queue, {"type": "all_done", "task_id": task_id})
532
 
533
 
534
+ # ─────────────────────────────────────────────────────────────────────────────
535
+ # Routes
536
+ # ─────────────────────────────────────────────────────────────────────────────
537
 
538
  @router.post("/api/train")
539
  async def start_training(
540
  session_id: str = Header(..., alias="X-Session-Id"),
541
  body: dict = None,
542
  ):
 
 
 
 
543
  session = session_manager.get(session_id)
544
  if session is None:
545
  raise HTTPException(status_code=404, detail="Session not found")
546
  if not session.features:
547
  raise HTTPException(status_code=400, detail="Extract features first: POST /api/features")
548
 
549
+ models = (body or {}).get("models", ALL_MODELS)
 
550
  task_id = str(uuid.uuid4())
 
551
  ws_manager.create_queue(task_id)
552
  session.training_tasks[task_id] = models
553
  session.touch()
554
 
555
  loop = asyncio.get_event_loop()
556
+ loop.run_in_executor(_executor, _train_all, task_id, session_id, models)
 
 
 
 
 
 
557
 
558
  return {"task_id": task_id, "models": models}
559
 
560
 
561
  @router.get("/api/results")
562
  async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
 
563
  session = session_manager.get(session_id)
564
  if session is None:
565
  raise HTTPException(status_code=404, detail="Session not found")
 
569
  }
570
 
571
 
 
 
572
  @router.websocket("/ws/train/{task_id}")
573
  async def training_websocket(websocket: WebSocket, task_id: str):
 
574
  await ws_manager.connect(task_id, websocket)
575
  try:
576
  await ws_manager.stream(task_id, websocket)