Spaces:
Sleeping
Sleeping
Fix feature extraction and models to match notebook
Browse files- ml/feature_extraction.py +202 -144
- routers/training.py +446 -119
ml/feature_extraction.py
CHANGED
|
@@ -1,162 +1,220 @@
|
|
| 1 |
"""
|
| 2 |
-
Feature extraction β 82-dimensional
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
[78] Crest factor (peak/RMS)
|
| 11 |
-
[79] Log peak amplitude
|
| 12 |
-
[80:82] Spectral centroid mean + std
|
| 13 |
"""
|
| 14 |
|
| 15 |
import warnings
|
| 16 |
import numpy as np
|
| 17 |
import librosa
|
| 18 |
from scipy.signal import welch
|
| 19 |
-
from scipy.optimize import curve_fit
|
| 20 |
|
| 21 |
-
from config import SR,
|
| 22 |
|
| 23 |
-
# ββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
for i in range(n_bins):
|
| 35 |
-
mask
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
total
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
"""
|
| 48 |
-
Returns (mfcc_mean, mfcc_std) β both shape (n_mfcc,).
|
| 49 |
-
Mean subtraction (per coefficient across the window) is applied first β
|
| 50 |
-
this is the single most impactful normalisation for cross-session robustness.
|
| 51 |
-
"""
|
| 52 |
with warnings.catch_warnings():
|
| 53 |
warnings.simplefilter("ignore")
|
| 54 |
-
mfcc
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
return
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# βββ Decay ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
-
|
| 62 |
-
def _exp_model(t, A, tau):
|
| 63 |
-
return A * np.exp(-t / tau)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def extract_decay(window: np.ndarray) -> float:
|
| 67 |
-
"""
|
| 68 |
-
Fit AΒ·exp(-t/Ο) to the RMS energy envelope.
|
| 69 |
-
Returns Ο (seconds). Larger Ο = slower decay = tighter flange.
|
| 70 |
-
Falls back to 0.05 on fit failure.
|
| 71 |
-
"""
|
| 72 |
-
# RMS in 5ms frames
|
| 73 |
-
frame_len = int(0.005 * SR)
|
| 74 |
-
n_frames = len(window) // frame_len
|
| 75 |
-
rms_env = np.array([
|
| 76 |
-
np.sqrt(np.mean(window[i * frame_len:(i + 1) * frame_len] ** 2))
|
| 77 |
-
for i in range(n_frames)
|
| 78 |
-
], dtype=np.float32)
|
| 79 |
-
t = np.arange(n_frames) * 0.005
|
| 80 |
-
# Trim to non-zero region
|
| 81 |
-
thresh = 0.02 * rms_env.max()
|
| 82 |
-
valid = rms_env > thresh
|
| 83 |
-
if valid.sum() < 5:
|
| 84 |
-
return 0.05
|
| 85 |
-
try:
|
| 86 |
-
p0 = (rms_env[valid].max(), 0.15)
|
| 87 |
-
popt, _ = curve_fit(_exp_model, t[valid], rms_env[valid], p0=p0,
|
| 88 |
-
bounds=([0, 0.001], [np.inf, 2.0]), maxfev=2000)
|
| 89 |
-
tau = float(np.clip(popt[1], 0.001, 2.0))
|
| 90 |
-
except Exception:
|
| 91 |
-
tau = 0.05
|
| 92 |
-
return tau
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
# βββ Energy ratio βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
-
|
| 97 |
-
def extract_energy_ratio(window: np.ndarray, split_ms: float = 50.0) -> float:
|
| 98 |
-
"""
|
| 99 |
-
E_late / E_early where split is at split_ms ms after the hit onset.
|
| 100 |
-
High ratio β flange still ringing β tight.
|
| 101 |
-
"""
|
| 102 |
-
split_n = int(split_ms / 1000 * SR)
|
| 103 |
-
early = float(np.sum(window[:split_n] ** 2)) + 1e-12
|
| 104 |
-
late = float(np.sum(window[split_n:] ** 2))
|
| 105 |
-
return min(late / early, 100.0) # cap for numerical safety
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# βββ Spectral centroid ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
-
|
| 110 |
-
def extract_spectral_centroid(window: np.ndarray) -> tuple[float, float]:
|
| 111 |
-
cent = librosa.feature.spectral_centroid(y=window, sr=SR, n_fft=N_FFT)[0]
|
| 112 |
-
return float(cent.mean()), float(cent.std())
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
# βββ Full 82-dim vector βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
-
|
| 117 |
-
def extract_features(window: np.ndarray) -> np.ndarray:
|
| 118 |
-
"""Return 82-dim feature vector for a single hit window."""
|
| 119 |
-
psd = extract_psd(window) # 50
|
| 120 |
-
mfcc_m, mfcc_s = extract_mfcc(window) # 13 + 13
|
| 121 |
-
tau = extract_decay(window) # 1
|
| 122 |
-
energy_ratio = extract_energy_ratio(window) # 1
|
| 123 |
-
peak_amp = float(np.abs(window).max())
|
| 124 |
-
rms = float(np.sqrt(np.mean(window ** 2))) + 1e-12
|
| 125 |
-
crest = float(peak_amp / rms) # 1
|
| 126 |
-
log_peak = float(np.log1p(peak_amp)) # 1
|
| 127 |
-
sc_mean, sc_std = extract_spectral_centroid(window) # 2
|
| 128 |
-
|
| 129 |
-
feat = np.concatenate([
|
| 130 |
-
psd,
|
| 131 |
-
mfcc_m,
|
| 132 |
-
mfcc_s,
|
| 133 |
-
[tau, energy_ratio, crest, log_peak, sc_mean, sc_std],
|
| 134 |
]).astype(np.float32)
|
| 135 |
-
return feat
|
| 136 |
-
|
| 137 |
|
| 138 |
-
FEATURE_NAMES: list[str] = (
|
| 139 |
-
[f"psd_{i}" for i in range(50)] +
|
| 140 |
-
[f"mfcc_mean_{i}" for i in range(13)] +
|
| 141 |
-
[f"mfcc_std_{i}" for i in range(13)] +
|
| 142 |
-
["tau", "energy_ratio", "crest_factor", "log_peak_amp", "sc_mean", "sc_std"]
|
| 143 |
-
)
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
)
|
| 156 |
-
mel_db = librosa.power_to_db(mel, ref=np.max)
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
mel_db
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Feature extraction β 82-dimensional physics-informed vector.
|
| 3 |
+
EXACTLY matches Cell 4 of final_project_saurav_silwal.ipynb.
|
| 4 |
+
|
| 5 |
+
Group 1: Relative PSD in 20 log-spaced bins (50 Hz β 8 kHz) β 20 dims
|
| 6 |
+
Group 2: MFCC mean/std + delta MFCC mean/std (13 coeffs each) β 52 dims
|
| 7 |
+
Group 3: Physics features (centroid, bandwidth, rolloff, ZCR,
|
| 8 |
+
peak freq, decay Ο, energy ratio, RMS, Q-factor) β 10 dims
|
| 9 |
+
Total β 82 dims
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import warnings
|
| 13 |
import numpy as np
|
| 14 |
import librosa
|
| 15 |
from scipy.signal import welch
|
|
|
|
| 16 |
|
| 17 |
+
from config import SR, N_MELS, N_FFT, HOP_LENGTH_MEL, SPEC_TIME_FRAMES
|
| 18 |
|
| 19 |
+
# ββ Constants matching notebook ββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
N_PSD_BINS = 20
|
| 21 |
+
PSD_FMIN = 50.0
|
| 22 |
+
PSD_FMAX = 8000.0
|
| 23 |
+
WELCH_NPERSEG = 2048
|
| 24 |
+
WELCH_NOVERLAP = 1024
|
| 25 |
|
| 26 |
+
N_MFCC = 13
|
| 27 |
+
MFCC_NFFT = 2048
|
| 28 |
+
MFCC_HOP = 512
|
| 29 |
+
|
| 30 |
+
DECAY_FIT_MS = 200
|
| 31 |
+
EARLY_LATE_FRAC = 0.20
|
| 32 |
+
PEAK_SAMPLE_IN_WIN = int(0.020 * SR) # 960 samples = 20 ms pre-peak
|
| 33 |
+
|
| 34 |
+
FMIN_MEL = 0
|
| 35 |
+
FMAX_MEL = SR // 2 # Nyquist = 24 000 Hz
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ββ Group 1: Relative PSD βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
|
| 40 |
+
def relative_psd_log_bins(y, sr=SR, n_bins=N_PSD_BINS,
|
| 41 |
+
f_min=PSD_FMIN, f_max=PSD_FMAX):
|
| 42 |
+
"""Welch PSD β 20 log-spaced bins β normalized so sum=1."""
|
| 43 |
+
f, pxx = welch(y, fs=sr,
|
| 44 |
+
nperseg=min(WELCH_NPERSEG, len(y)),
|
| 45 |
+
noverlap=min(WELCH_NOVERLAP, len(y) // 2))
|
| 46 |
+
edges = np.logspace(np.log10(f_min), np.log10(f_max), n_bins + 1)
|
| 47 |
+
bins = np.zeros(n_bins, dtype=np.float32)
|
| 48 |
for i in range(n_bins):
|
| 49 |
+
mask = (f >= edges[i]) & (f < edges[i + 1])
|
| 50 |
+
bins[i] = pxx[mask].sum()
|
| 51 |
+
total = bins.sum()
|
| 52 |
+
if total > 1e-20:
|
| 53 |
+
bins /= total
|
| 54 |
+
return bins, f, pxx
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ββ Group 2: MFCC + delta statistics ββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
|
| 59 |
+
def mfcc_stats(y, sr=SR):
|
| 60 |
+
"""13 MFCCs β mean+std (26) + delta mean+std (26) = 52 dims."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
with warnings.catch_warnings():
|
| 62 |
warnings.simplefilter("ignore")
|
| 63 |
+
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC,
|
| 64 |
+
n_fft=MFCC_NFFT, hop_length=MFCC_HOP)
|
| 65 |
+
delta = librosa.feature.delta(mfcc)
|
| 66 |
+
return np.concatenate([
|
| 67 |
+
mfcc.mean(axis=1), mfcc.std(axis=1),
|
| 68 |
+
delta.mean(axis=1), delta.std(axis=1),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
]).astype(np.float32)
|
|
|
|
|
|
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# ββ Group 3: Physics features ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
def peak_frequency(f, pxx):
|
| 75 |
+
if pxx.max() <= 0:
|
| 76 |
+
return 0.0
|
| 77 |
+
return float(f[np.argmax(pxx)])
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def q_factor(f, pxx):
|
| 81 |
+
"""Q = f_peak / -3 dB bandwidth. High Q = tight (rings cleanly)."""
|
| 82 |
+
if pxx.max() <= 0:
|
| 83 |
+
return 0.0
|
| 84 |
+
pdb = 10 * np.log10(pxx + 1e-20)
|
| 85 |
+
peak_idx = int(np.argmax(pdb))
|
| 86 |
+
threshold = pdb[peak_idx] - 3.0
|
| 87 |
+
L = peak_idx
|
| 88 |
+
while L > 0 and pdb[L] >= threshold:
|
| 89 |
+
L -= 1
|
| 90 |
+
R = peak_idx
|
| 91 |
+
while R < len(pdb) - 1 and pdb[R] >= threshold:
|
| 92 |
+
R += 1
|
| 93 |
+
bw = max(f[R] - f[L], 1.0)
|
| 94 |
+
return float(f[peak_idx] / bw)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def decay_tau(y, peak_sample=PEAK_SAMPLE_IN_WIN, sr=SR, fit_ms=DECAY_FIT_MS):
|
| 98 |
+
"""Decay time constant Ο. Loose β small Ο. Tight β large Ο."""
|
| 99 |
+
n_fit = int(fit_ms * sr / 1000)
|
| 100 |
+
seg = y[peak_sample:min(peak_sample + n_fit, len(y))]
|
| 101 |
+
if len(seg) < 100:
|
| 102 |
+
return np.nan
|
| 103 |
+
env_w = max(1, int(0.005 * sr))
|
| 104 |
+
env = np.convolve(np.abs(seg), np.ones(env_w) / env_w, mode='same')
|
| 105 |
+
if env.max() < 1e-8:
|
| 106 |
+
return np.nan
|
| 107 |
+
active = np.where(env > 0.05 * env.max())[0]
|
| 108 |
+
if len(active) < 50:
|
| 109 |
+
return np.nan
|
| 110 |
+
n_active = active[-1] + 1
|
| 111 |
+
eps = env.max() * 1e-4
|
| 112 |
+
log_env = np.log(env[:n_active] + eps)
|
| 113 |
+
t = np.arange(n_active) / sr
|
| 114 |
+
slope, _ = np.polyfit(t, log_env, 1)
|
| 115 |
+
if slope >= 0:
|
| 116 |
+
return np.nan
|
| 117 |
+
tau = -1.0 / slope
|
| 118 |
+
return float(tau) if 0.001 <= tau <= 10.0 else np.nan
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def energy_ratio(y, frac=EARLY_LATE_FRAC):
|
| 122 |
+
"""E_late / E_early. Tight flanges still ringing β high ratio."""
|
| 123 |
+
n_chunk = int(frac * len(y))
|
| 124 |
+
e_early = np.sqrt(np.mean(y[:n_chunk] ** 2))
|
| 125 |
+
e_late = np.sqrt(np.mean(y[-n_chunk:] ** 2))
|
| 126 |
+
return float(e_late / (e_early + 1e-12))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ββ Master 82-dim extractor βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
|
| 131 |
+
def extract_features(y: np.ndarray, sr: int = SR) -> np.ndarray:
|
| 132 |
+
"""Return 82-dim feature vector for one hit window."""
|
| 133 |
+
psd_bins, f_psd, pxx = relative_psd_log_bins(y, sr) # 20
|
| 134 |
+
|
| 135 |
+
cepstral = mfcc_stats(y, sr) # 52
|
| 136 |
|
| 137 |
+
with warnings.catch_warnings():
|
| 138 |
+
warnings.simplefilter("ignore")
|
| 139 |
+
sc = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
|
| 140 |
+
sb = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
|
| 141 |
+
sr85 = librosa.feature.spectral_rolloff(y=y, sr=sr, roll_percent=0.85)[0]
|
| 142 |
+
zcr = librosa.feature.zero_crossing_rate(y)[0]
|
| 143 |
+
|
| 144 |
+
physics = np.array([
|
| 145 |
+
sc.mean(), sc.std(), # 2: centroid mean/std
|
| 146 |
+
sb.mean(), # 1: bandwidth mean
|
| 147 |
+
sr85.mean(), # 1: rolloff 85%
|
| 148 |
+
zcr.mean(), # 1: zero-crossing rate
|
| 149 |
+
peak_frequency(f_psd, pxx), # 1: dominant freq
|
| 150 |
+
decay_tau(y), # 1: Ο (NaN β imputed after)
|
| 151 |
+
energy_ratio(y), # 1: E_late / E_early
|
| 152 |
+
float(np.sqrt(np.mean(y ** 2))),# 1: RMS energy
|
| 153 |
+
q_factor(f_psd, pxx), # 1: Q-factor
|
| 154 |
+
], dtype=np.float32) # 10 total
|
| 155 |
+
|
| 156 |
+
return np.concatenate([psd_bins, cepstral, physics])
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def impute_nans(X: np.ndarray, y_labels: np.ndarray, n_classes: int = 3) -> np.ndarray:
|
| 160 |
+
"""Per-class median imputation for NaN columns (tau can be NaN)."""
|
| 161 |
+
X = X.copy()
|
| 162 |
+
nan_cols = np.where(np.isnan(X).any(axis=0))[0]
|
| 163 |
+
for c in nan_cols:
|
| 164 |
+
for cls in range(n_classes):
|
| 165 |
+
cls_mask = (y_labels == cls)
|
| 166 |
+
median_val = float(np.nanmedian(X[cls_mask, c]))
|
| 167 |
+
if np.isnan(median_val):
|
| 168 |
+
median_val = float(np.nanmedian(X[:, c]))
|
| 169 |
+
fill_mask = cls_mask & np.isnan(X[:, c])
|
| 170 |
+
X[fill_mask, c] = median_val
|
| 171 |
+
return X
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ββ Feature name list βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
|
| 176 |
+
def _build_feature_names() -> list[str]:
|
| 177 |
+
names = []
|
| 178 |
+
edges = np.logspace(np.log10(PSD_FMIN), np.log10(PSD_FMAX), N_PSD_BINS + 1)
|
| 179 |
+
for i in range(N_PSD_BINS):
|
| 180 |
+
names.append(f'psd_{edges[i]:.0f}_{edges[i+1]:.0f}Hz')
|
| 181 |
+
names += [f'mfcc{i:02d}_mean' for i in range(N_MFCC)]
|
| 182 |
+
names += [f'mfcc{i:02d}_std' for i in range(N_MFCC)]
|
| 183 |
+
names += [f'dmfcc{i:02d}_mean' for i in range(N_MFCC)]
|
| 184 |
+
names += [f'dmfcc{i:02d}_std' for i in range(N_MFCC)]
|
| 185 |
+
names += ['spec_centroid_mean', 'spec_centroid_std',
|
| 186 |
+
'spec_bandwidth_mean', 'spec_rolloff85_mean',
|
| 187 |
+
'zero_cross_rate_mean', 'peak_frequency',
|
| 188 |
+
'decay_tau', 'energy_ratio', 'rms_energy', 'q_factor']
|
| 189 |
+
return names
|
| 190 |
+
|
| 191 |
+
FEATURE_NAMES: list[str] = _build_feature_names()
|
| 192 |
+
assert len(FEATURE_NAMES) == 82
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ββ Mel spectrogram (CNN / BiLSTM input) β matches notebook Cell 5 ββββββββ
|
| 196 |
+
|
| 197 |
+
def extract_mel_spectrogram(y: np.ndarray, sr: int = SR,
|
| 198 |
+
n_mels: int = N_MELS,
|
| 199 |
+
n_fft: int = N_FFT,
|
| 200 |
+
hop_length: int = HOP_LENGTH_MEL,
|
| 201 |
+
target_frames: int = SPEC_TIME_FRAMES) -> np.ndarray:
|
| 202 |
+
"""One hit β standardized log-mel spectrogram of shape (n_mels, target_frames)."""
|
| 203 |
+
mel = librosa.feature.melspectrogram(
|
| 204 |
+
y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
|
| 205 |
+
fmin=FMIN_MEL, fmax=FMAX_MEL, power=2.0,
|
| 206 |
)
|
| 207 |
+
mel_db = librosa.power_to_db(mel, ref=np.max).astype(np.float32)
|
| 208 |
+
|
| 209 |
+
n_frames = mel_db.shape[1]
|
| 210 |
+
if n_frames < target_frames:
|
| 211 |
+
pad_val = float(mel_db.min())
|
| 212 |
+
mel_db = np.pad(mel_db, ((0, 0), (0, target_frames - n_frames)),
|
| 213 |
+
mode='constant', constant_values=pad_val)
|
| 214 |
+
elif n_frames > target_frames:
|
| 215 |
+
start = (n_frames - target_frames) // 2
|
| 216 |
+
mel_db = mel_db[:, start:start + target_frames]
|
| 217 |
+
|
| 218 |
+
# per-sample standardize
|
| 219 |
+
mu, sigma = mel_db.mean(), mel_db.std()
|
| 220 |
+
return (mel_db - mu) / (sigma + 1e-6) # shape (64, 128)
|
routers/training.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
Router: POST /api/train β launch training as a background task
|
| 3 |
GET /api/results β fetch completed results
|
| 4 |
-
WS /ws/train/{task_id} β stream live epoch metrics
|
| 5 |
|
| 6 |
-
Models
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import asyncio
|
| 11 |
import uuid
|
| 12 |
-
import threading
|
| 13 |
import traceback
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
|
|
@@ -19,123 +22,225 @@ from sklearn.preprocessing import StandardScaler
|
|
| 19 |
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
|
| 20 |
from sklearn.svm import SVC
|
| 21 |
from sklearn.linear_model import LogisticRegression
|
| 22 |
-
from sklearn.ensemble import RandomForestClassifier
|
| 23 |
-
from sklearn.neural_network import MLPClassifier
|
| 24 |
from sklearn.neighbors import KNeighborsClassifier
|
| 25 |
from fastapi import APIRouter, Header, HTTPException, WebSocket, WebSocketDisconnect
|
| 26 |
|
| 27 |
from session import session_manager
|
| 28 |
from ws.training_ws import ws_manager, emit
|
|
|
|
|
|
|
|
|
|
| 29 |
from config import SEED, TEST_SIZE, IDX_TO_CLASS, CLASS_NAMES, N_CLASSES
|
| 30 |
|
| 31 |
-
router
|
| 32 |
-
|
| 33 |
-
# Shared thread pool for background training
|
| 34 |
_executor = ThreadPoolExecutor(max_workers=2)
|
| 35 |
|
| 36 |
|
| 37 |
-
# βββ
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
"SVM": lambda: SVC(
|
| 41 |
kernel="rbf", C=10.0, gamma="scale",
|
| 42 |
-
probability=True, class_weight="balanced", random_state=SEED
|
| 43 |
),
|
| 44 |
"LR": lambda: LogisticRegression(
|
| 45 |
C=1.0, max_iter=2000, class_weight="balanced",
|
| 46 |
-
multi_class="multinomial", solver="lbfgs", random_state=SEED
|
| 47 |
-
),
|
| 48 |
-
"RF": lambda: RandomForestClassifier(
|
| 49 |
-
n_estimators=200, max_depth=None,
|
| 50 |
-
class_weight="balanced", random_state=SEED, n_jobs=-1
|
| 51 |
),
|
| 52 |
-
"
|
| 53 |
-
|
| 54 |
-
max_iter=500, early_stopping=True, random_state=SEED
|
| 55 |
),
|
| 56 |
-
"KNN": lambda: KNeighborsClassifier(n_neighbors=5, metric="euclidean"),
|
| 57 |
}
|
| 58 |
|
| 59 |
|
| 60 |
-
# βββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
try:
|
| 73 |
scaler = StandardScaler()
|
| 74 |
-
X_s
|
| 75 |
|
| 76 |
-
#
|
| 77 |
X_tr, X_te, y_tr, y_te = train_test_split(
|
| 78 |
X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
|
| 79 |
)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
acc_t1
|
| 84 |
-
f1_t1
|
| 85 |
-
cm_t1
|
| 86 |
-
|
| 87 |
-
emit(loop, queue, {
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
fold_records: list[dict] = []
|
| 96 |
|
| 97 |
for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
|
| 98 |
flange_out = int(groups[te_idx[0]])
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
acc_f = float(accuracy_score(y[te_idx], y_p))
|
| 104 |
fold_accs.append(acc_f)
|
| 105 |
-
fold_records.append({
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# Final model on all data (for ensemble / CORAL)
|
| 120 |
-
clf_final = SHALLOW_MODELS[model_name]()
|
| 121 |
clf_final.fit(X_s, y)
|
| 122 |
-
|
| 123 |
-
train_acc = float(accuracy_score(y, y_pred_all))
|
| 124 |
-
cm_loio_pooled = confusion_matrix(
|
| 125 |
-
[f["flange_out"] for f in fold_records], # dummy β use actual pooled
|
| 126 |
-
[f["flange_out"] for f in fold_records],
|
| 127 |
-
).tolist()
|
| 128 |
-
|
| 129 |
-
# Pooled LOIO confusion matrix
|
| 130 |
-
all_y_true: list[int] = []
|
| 131 |
-
all_y_pred: list[int] = []
|
| 132 |
-
for tr_idx, te_idx in logo.split(X_s, y, groups):
|
| 133 |
-
clf = SHALLOW_MODELS[model_name]()
|
| 134 |
-
clf.fit(X_s[tr_idx], y[tr_idx])
|
| 135 |
-
all_y_true.extend(y[te_idx].tolist())
|
| 136 |
-
all_y_pred.extend(clf.predict(X_s[te_idx]).tolist())
|
| 137 |
-
cm_t2 = confusion_matrix(all_y_true, all_y_pred, labels=[0, 1, 2]).tolist()
|
| 138 |
-
f1_t2 = float(f1_score(all_y_true, all_y_pred, average="macro"))
|
| 139 |
|
| 140 |
result = {
|
| 141 |
"model": model_name,
|
|
@@ -143,17 +248,90 @@ def _train_shallow(
|
|
| 143 |
"task1_f1": round(f1_t1, 4),
|
| 144 |
"task1_cm": cm_t1,
|
| 145 |
"task2_mean": round(float(np.mean(fold_accs)), 4),
|
| 146 |
-
"task2_std": round(float(np.std(fold_accs)),
|
| 147 |
"task2_f1": round(f1_t2, 4),
|
| 148 |
"task2_cm": cm_t2,
|
| 149 |
"folds": fold_records,
|
| 150 |
"train_acc": round(train_acc, 4),
|
| 151 |
-
"scaler_mean": scaler.mean_.tolist(),
|
| 152 |
-
"scaler_scale": scaler.scale_.tolist(),
|
| 153 |
}
|
| 154 |
session.training_results[model_name] = result
|
| 155 |
session.touch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
emit(loop, queue, {"type": "model_done", "model": model_name, **result})
|
| 158 |
|
| 159 |
except Exception as e:
|
|
@@ -161,15 +339,167 @@ def _train_shallow(
|
|
| 161 |
traceback.print_exc()
|
| 162 |
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
loop = asyncio.new_event_loop()
|
| 167 |
asyncio.set_event_loop(loop)
|
| 168 |
|
| 169 |
session = session_manager.get(session_id)
|
| 170 |
if session is None:
|
| 171 |
return
|
| 172 |
-
|
| 173 |
queue = ws_manager.get_queue(task_id)
|
| 174 |
if queue is None:
|
| 175 |
return
|
|
@@ -179,57 +509,57 @@ def _train_all_models(task_id: str, session_id: str, models: list[str]):
|
|
| 179 |
emit(loop, queue, {"type": "error", "message": "Features not extracted yet"})
|
| 180 |
return
|
| 181 |
|
| 182 |
-
X = np.array(feats["X_feat"],
|
| 183 |
-
y = np.array(feats["labels"],
|
| 184 |
-
groups = np.array(feats["flange_groups"],
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
emit(loop, queue, {"type": "all_done", "task_id": task_id})
|
| 191 |
|
| 192 |
|
| 193 |
-
# βββ
|
|
|
|
|
|
|
| 194 |
|
| 195 |
@router.post("/api/train")
|
| 196 |
async def start_training(
|
| 197 |
session_id: str = Header(..., alias="X-Session-Id"),
|
| 198 |
body: dict = None,
|
| 199 |
):
|
| 200 |
-
"""
|
| 201 |
-
Launch background training. Returns task_id for WebSocket connection.
|
| 202 |
-
Body: {"models": ["SVM", "LR", "RF", "MLP", "KNN"]}
|
| 203 |
-
"""
|
| 204 |
session = session_manager.get(session_id)
|
| 205 |
if session is None:
|
| 206 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 207 |
if not session.features:
|
| 208 |
raise HTTPException(status_code=400, detail="Extract features first: POST /api/features")
|
| 209 |
|
| 210 |
-
models
|
| 211 |
-
|
| 212 |
task_id = str(uuid.uuid4())
|
| 213 |
-
# Create queue before starting thread (thread will use it immediately)
|
| 214 |
ws_manager.create_queue(task_id)
|
| 215 |
session.training_tasks[task_id] = models
|
| 216 |
session.touch()
|
| 217 |
|
| 218 |
loop = asyncio.get_event_loop()
|
| 219 |
-
loop.run_in_executor(
|
| 220 |
-
_executor,
|
| 221 |
-
_train_all_models,
|
| 222 |
-
task_id,
|
| 223 |
-
session_id,
|
| 224 |
-
models,
|
| 225 |
-
)
|
| 226 |
|
| 227 |
return {"task_id": task_id, "models": models}
|
| 228 |
|
| 229 |
|
| 230 |
@router.get("/api/results")
|
| 231 |
async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
|
| 232 |
-
"""Return all completed training results for this session."""
|
| 233 |
session = session_manager.get(session_id)
|
| 234 |
if session is None:
|
| 235 |
raise HTTPException(status_code=404, detail="Session not found")
|
|
@@ -239,11 +569,8 @@ async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
|
|
| 239 |
}
|
| 240 |
|
| 241 |
|
| 242 |
-
# βββ WebSocket endpoint βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 243 |
-
|
| 244 |
@router.websocket("/ws/train/{task_id}")
|
| 245 |
async def training_websocket(websocket: WebSocket, task_id: str):
|
| 246 |
-
"""Stream live training events for a given task_id."""
|
| 247 |
await ws_manager.connect(task_id, websocket)
|
| 248 |
try:
|
| 249 |
await ws_manager.stream(task_id, websocket)
|
|
|
|
| 1 |
"""
|
| 2 |
Router: POST /api/train β launch training as a background task
|
| 3 |
GET /api/results β fetch completed results
|
| 4 |
+
WS /ws/train/{task_id} β stream live epoch/fold metrics
|
| 5 |
|
| 6 |
+
Models (matching final_project_saurav_silwal.ipynb exactly):
|
| 7 |
+
Shallow (82-dim tabular features): SVM, LR, KNN
|
| 8 |
+
Deep (82-dim tabular): MLP (Keras, 3 hidden layers + dropout)
|
| 9 |
+
Deep (mel spectrogram): CNN, BiLSTM (Keras)
|
| 10 |
+
|
| 11 |
+
All models: Task 1 (70/30 dependent split) + Task 2 (LOIO cross-validation)
|
| 12 |
"""
|
| 13 |
|
| 14 |
import asyncio
|
| 15 |
import uuid
|
|
|
|
| 16 |
import traceback
|
| 17 |
from concurrent.futures import ThreadPoolExecutor
|
| 18 |
|
|
|
|
| 22 |
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
|
| 23 |
from sklearn.svm import SVC
|
| 24 |
from sklearn.linear_model import LogisticRegression
|
|
|
|
|
|
|
| 25 |
from sklearn.neighbors import KNeighborsClassifier
|
| 26 |
from fastapi import APIRouter, Header, HTTPException, WebSocket, WebSocketDisconnect
|
| 27 |
|
| 28 |
from session import session_manager
|
| 29 |
from ws.training_ws import ws_manager, emit
|
| 30 |
+
from ml.feature_extraction import (
|
| 31 |
+
extract_mel_spectrogram, impute_nans, FEATURE_NAMES
|
| 32 |
+
)
|
| 33 |
from config import SEED, TEST_SIZE, IDX_TO_CLASS, CLASS_NAMES, N_CLASSES
|
| 34 |
|
| 35 |
+
router = APIRouter(tags=["training"])
|
|
|
|
|
|
|
| 36 |
_executor = ThreadPoolExecutor(max_workers=2)
|
| 37 |
|
| 38 |
|
| 39 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
# Shallow model factories (sklearn)
|
| 41 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
|
| 43 |
+
SHALLOW_FACTORIES = {
|
| 44 |
"SVM": lambda: SVC(
|
| 45 |
kernel="rbf", C=10.0, gamma="scale",
|
| 46 |
+
probability=True, class_weight="balanced", random_state=SEED,
|
| 47 |
),
|
| 48 |
"LR": lambda: LogisticRegression(
|
| 49 |
C=1.0, max_iter=2000, class_weight="balanced",
|
| 50 |
+
multi_class="multinomial", solver="lbfgs", random_state=SEED,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
),
|
| 52 |
+
"KNN": lambda: KNeighborsClassifier(
|
| 53 |
+
n_neighbors=5, metric="euclidean", weights="uniform",
|
|
|
|
| 54 |
),
|
|
|
|
| 55 |
}
|
| 56 |
|
| 57 |
|
| 58 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
# Keras model builders
|
| 60 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
|
| 62 |
+
def build_mlp(input_dim: int, n_classes: int = 3):
|
| 63 |
+
"""3-layer MLP with BatchNorm + Dropout. Matches notebook Cell 9."""
|
| 64 |
+
import tensorflow as tf
|
| 65 |
+
from tensorflow import keras
|
| 66 |
+
|
| 67 |
+
model = keras.Sequential([
|
| 68 |
+
keras.layers.Input(shape=(input_dim,)),
|
| 69 |
+
keras.layers.Dense(256),
|
| 70 |
+
keras.layers.BatchNormalization(),
|
| 71 |
+
keras.layers.Activation("relu"),
|
| 72 |
+
keras.layers.Dropout(0.4),
|
| 73 |
+
|
| 74 |
+
keras.layers.Dense(128),
|
| 75 |
+
keras.layers.BatchNormalization(),
|
| 76 |
+
keras.layers.Activation("relu"),
|
| 77 |
+
keras.layers.Dropout(0.3),
|
| 78 |
+
|
| 79 |
+
keras.layers.Dense(64),
|
| 80 |
+
keras.layers.BatchNormalization(),
|
| 81 |
+
keras.layers.Activation("relu"),
|
| 82 |
+
keras.layers.Dropout(0.2),
|
| 83 |
+
|
| 84 |
+
keras.layers.Dense(n_classes, activation="softmax"),
|
| 85 |
+
])
|
| 86 |
+
model.compile(
|
| 87 |
+
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
|
| 88 |
+
loss="sparse_categorical_crossentropy",
|
| 89 |
+
metrics=["accuracy"],
|
| 90 |
+
)
|
| 91 |
+
return model
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def build_cnn(n_mels: int = 64, n_frames: int = 128, n_classes: int = 3):
|
| 95 |
+
"""CNN on log-mel spectrogram (64Γ128Γ1). Matches notebook Cell 10."""
|
| 96 |
+
import tensorflow as tf
|
| 97 |
+
from tensorflow import keras
|
| 98 |
+
|
| 99 |
+
model = keras.Sequential([
|
| 100 |
+
keras.layers.Input(shape=(n_mels, n_frames, 1)),
|
| 101 |
+
|
| 102 |
+
keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
|
| 103 |
+
keras.layers.BatchNormalization(),
|
| 104 |
+
keras.layers.MaxPooling2D((2, 2)),
|
| 105 |
+
keras.layers.Dropout(0.25),
|
| 106 |
+
|
| 107 |
+
keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu"),
|
| 108 |
+
keras.layers.BatchNormalization(),
|
| 109 |
+
keras.layers.MaxPooling2D((2, 2)),
|
| 110 |
+
keras.layers.Dropout(0.25),
|
| 111 |
+
|
| 112 |
+
keras.layers.Conv2D(128, (3, 3), padding="same", activation="relu"),
|
| 113 |
+
keras.layers.BatchNormalization(),
|
| 114 |
+
keras.layers.GlobalAveragePooling2D(),
|
| 115 |
+
keras.layers.Dropout(0.4),
|
| 116 |
+
|
| 117 |
+
keras.layers.Dense(128, activation="relu"),
|
| 118 |
+
keras.layers.Dropout(0.3),
|
| 119 |
+
keras.layers.Dense(n_classes, activation="softmax"),
|
| 120 |
+
])
|
| 121 |
+
model.compile(
|
| 122 |
+
optimizer="adam",
|
| 123 |
+
loss="sparse_categorical_crossentropy",
|
| 124 |
+
metrics=["accuracy"],
|
| 125 |
+
)
|
| 126 |
+
return model
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def build_bilstm(n_frames: int = 128, n_mels: int = 64, n_classes: int = 3):
|
| 130 |
+
"""Bidirectional LSTM on mel sequences (128 time steps Γ 64 mel features).
|
| 131 |
+
Matches notebook Cell 11."""
|
| 132 |
+
import tensorflow as tf
|
| 133 |
+
from tensorflow import keras
|
| 134 |
+
|
| 135 |
+
model = keras.Sequential([
|
| 136 |
+
keras.layers.Input(shape=(n_frames, n_mels)),
|
| 137 |
+
keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True)),
|
| 138 |
+
keras.layers.Dropout(0.3),
|
| 139 |
+
keras.layers.Bidirectional(keras.layers.LSTM(32)),
|
| 140 |
+
keras.layers.Dropout(0.3),
|
| 141 |
+
keras.layers.Dense(64, activation="relu"),
|
| 142 |
+
keras.layers.Dropout(0.2),
|
| 143 |
+
keras.layers.Dense(n_classes, activation="softmax"),
|
| 144 |
+
])
|
| 145 |
+
model.compile(
|
| 146 |
+
optimizer="adam",
|
| 147 |
+
loss="sparse_categorical_crossentropy",
|
| 148 |
+
metrics=["accuracy"],
|
| 149 |
+
)
|
| 150 |
+
return model
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 154 |
+
# WebSocket epoch callback for Keras
|
| 155 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 156 |
+
|
| 157 |
+
class WSCallback:
|
| 158 |
+
"""Keras callback that emits epoch metrics over WebSocket."""
|
| 159 |
+
|
| 160 |
+
def __init__(self, loop, queue, model_name, total_epochs):
|
| 161 |
+
self.loop = loop
|
| 162 |
+
self.queue = queue
|
| 163 |
+
self.model_name = model_name
|
| 164 |
+
self.total_epochs = total_epochs
|
| 165 |
+
|
| 166 |
+
def on_epoch_end(self, epoch, logs=None):
|
| 167 |
+
logs = logs or {}
|
| 168 |
+
emit(self.loop, self.queue, {
|
| 169 |
+
"type": "epoch",
|
| 170 |
+
"model": self.model_name,
|
| 171 |
+
"epoch": epoch + 1,
|
| 172 |
+
"total": self.total_epochs,
|
| 173 |
+
"train_acc": round(float(logs.get("accuracy", 0)), 4),
|
| 174 |
+
"val_acc": round(float(logs.get("val_accuracy", 0)), 4),
|
| 175 |
+
"train_loss": round(float(logs.get("loss", 0)), 4),
|
| 176 |
+
"val_loss": round(float(logs.get("val_loss", 0)), 4),
|
| 177 |
+
})
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _make_keras_callback(loop, queue, model_name, total_epochs):
|
| 181 |
+
"""Return a tf.keras.callbacks.Callback subclass instance."""
|
| 182 |
+
import tensorflow as tf
|
| 183 |
|
| 184 |
+
cb = WSCallback(loop, queue, model_name, total_epochs)
|
| 185 |
+
|
| 186 |
+
class _CB(tf.keras.callbacks.Callback):
|
| 187 |
+
def on_epoch_end(self, epoch, logs=None):
|
| 188 |
+
cb.on_epoch_end(epoch, logs)
|
| 189 |
+
|
| 190 |
+
return _CB()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 194 |
+
# Shallow model training (SVM / LR / KNN)
|
| 195 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 196 |
+
|
| 197 |
+
def _train_shallow(model_name, X, y, groups, loop, queue, session):
|
| 198 |
try:
|
| 199 |
scaler = StandardScaler()
|
| 200 |
+
X_s = scaler.fit_transform(X)
|
| 201 |
|
| 202 |
+
# Task 1
|
| 203 |
X_tr, X_te, y_tr, y_te = train_test_split(
|
| 204 |
X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
|
| 205 |
)
|
| 206 |
+
clf = SHALLOW_FACTORIES[model_name]()
|
| 207 |
+
clf.fit(X_tr, y_tr)
|
| 208 |
+
y_p1 = clf.predict(X_te)
|
| 209 |
+
acc_t1 = float(accuracy_score(y_te, y_p1))
|
| 210 |
+
f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
|
| 211 |
+
cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
|
| 212 |
+
|
| 213 |
+
emit(loop, queue, {"type": "task1_done", "model": model_name,
|
| 214 |
+
"acc": round(acc_t1, 4), "f1": round(f1_t1, 4)})
|
| 215 |
+
|
| 216 |
+
# Task 2 β LOIO
|
| 217 |
+
logo = LeaveOneGroupOut()
|
| 218 |
+
fold_accs = []
|
| 219 |
+
fold_records = []
|
| 220 |
+
all_yt, all_yp = [], []
|
|
|
|
| 221 |
|
| 222 |
for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
|
| 223 |
flange_out = int(groups[te_idx[0]])
|
| 224 |
+
clf2 = SHALLOW_FACTORIES[model_name]()
|
| 225 |
+
clf2.fit(X_s[tr_idx], y[tr_idx])
|
| 226 |
+
yp = clf2.predict(X_s[te_idx])
|
| 227 |
+
acc_f = float(accuracy_score(y[te_idx], yp))
|
|
|
|
| 228 |
fold_accs.append(acc_f)
|
| 229 |
+
fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
|
| 230 |
+
"acc": round(acc_f, 4), "n_test": len(te_idx)})
|
| 231 |
+
all_yt.extend(y[te_idx].tolist())
|
| 232 |
+
all_yp.extend(yp.tolist())
|
| 233 |
+
emit(loop, queue, {"type": "fold_done", "model": model_name,
|
| 234 |
+
"fold": fold_i + 1, "flange_out": flange_out,
|
| 235 |
+
"acc": round(acc_f, 4)})
|
| 236 |
+
|
| 237 |
+
cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
|
| 238 |
+
f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
|
| 239 |
+
|
| 240 |
+
# Final model on all data
|
| 241 |
+
clf_final = SHALLOW_FACTORIES[model_name]()
|
|
|
|
|
|
|
|
|
|
| 242 |
clf_final.fit(X_s, y)
|
| 243 |
+
train_acc = float(accuracy_score(y, clf_final.predict(X_s)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
result = {
|
| 246 |
"model": model_name,
|
|
|
|
| 248 |
"task1_f1": round(f1_t1, 4),
|
| 249 |
"task1_cm": cm_t1,
|
| 250 |
"task2_mean": round(float(np.mean(fold_accs)), 4),
|
| 251 |
+
"task2_std": round(float(np.std(fold_accs)), 4),
|
| 252 |
"task2_f1": round(f1_t2, 4),
|
| 253 |
"task2_cm": cm_t2,
|
| 254 |
"folds": fold_records,
|
| 255 |
"train_acc": round(train_acc, 4),
|
|
|
|
|
|
|
| 256 |
}
|
| 257 |
session.training_results[model_name] = result
|
| 258 |
session.touch()
|
| 259 |
+
emit(loop, queue, {"type": "model_done", "model": model_name, **result})
|
| 260 |
+
|
| 261 |
+
except Exception as e:
|
| 262 |
+
emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
|
| 263 |
+
traceback.print_exc()
|
| 264 |
|
| 265 |
+
|
| 266 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 267 |
+
# MLP training (Keras, tabular features)
|
| 268 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 269 |
+
|
| 270 |
+
def _train_mlp(X, y, groups, loop, queue, session, epochs=60):
|
| 271 |
+
model_name = "MLP"
|
| 272 |
+
try:
|
| 273 |
+
import tensorflow as tf
|
| 274 |
+
tf.random.set_seed(SEED)
|
| 275 |
+
|
| 276 |
+
scaler = StandardScaler()
|
| 277 |
+
X_s = scaler.fit_transform(X)
|
| 278 |
+
|
| 279 |
+
# Task 1
|
| 280 |
+
X_tr, X_te, y_tr, y_te = train_test_split(
|
| 281 |
+
X_s, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
|
| 282 |
+
)
|
| 283 |
+
model = build_mlp(X_s.shape[1])
|
| 284 |
+
cb = _make_keras_callback(loop, queue, model_name, epochs)
|
| 285 |
+
es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
|
| 286 |
+
model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
|
| 287 |
+
validation_split=0.15,
|
| 288 |
+
callbacks=[cb, es], verbose=0)
|
| 289 |
+
|
| 290 |
+
y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
|
| 291 |
+
acc_t1 = float(accuracy_score(y_te, y_p1))
|
| 292 |
+
f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
|
| 293 |
+
cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
|
| 294 |
+
|
| 295 |
+
emit(loop, queue, {"type": "task1_done", "model": model_name,
|
| 296 |
+
"acc": round(acc_t1, 4), "f1": round(f1_t1, 4)})
|
| 297 |
+
|
| 298 |
+
# Task 2 β LOIO
|
| 299 |
+
logo = LeaveOneGroupOut()
|
| 300 |
+
fold_accs = []
|
| 301 |
+
fold_records = []
|
| 302 |
+
all_yt, all_yp = [], []
|
| 303 |
+
|
| 304 |
+
for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_s, y, groups)):
|
| 305 |
+
flange_out = int(groups[te_idx[0]])
|
| 306 |
+
m2 = build_mlp(X_s.shape[1])
|
| 307 |
+
es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
|
| 308 |
+
m2.fit(X_s[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
|
| 309 |
+
validation_split=0.15, callbacks=[es2], verbose=0)
|
| 310 |
+
yp = np.argmax(m2.predict(X_s[te_idx], verbose=0), axis=1)
|
| 311 |
+
acc_f = float(accuracy_score(y[te_idx], yp))
|
| 312 |
+
fold_accs.append(acc_f)
|
| 313 |
+
fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
|
| 314 |
+
"acc": round(acc_f, 4), "n_test": len(te_idx)})
|
| 315 |
+
all_yt.extend(y[te_idx].tolist())
|
| 316 |
+
all_yp.extend(yp.tolist())
|
| 317 |
+
emit(loop, queue, {"type": "fold_done", "model": model_name,
|
| 318 |
+
"fold": fold_i + 1, "flange_out": flange_out,
|
| 319 |
+
"acc": round(acc_f, 4)})
|
| 320 |
+
|
| 321 |
+
cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
|
| 322 |
+
f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
|
| 323 |
+
train_acc = float(accuracy_score(y, np.argmax(model.predict(X_s, verbose=0), axis=1)))
|
| 324 |
+
|
| 325 |
+
result = {
|
| 326 |
+
"model": model_name, "task1_acc": round(acc_t1, 4),
|
| 327 |
+
"task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
|
| 328 |
+
"task2_mean": round(float(np.mean(fold_accs)), 4),
|
| 329 |
+
"task2_std": round(float(np.std(fold_accs)), 4),
|
| 330 |
+
"task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
|
| 331 |
+
"folds": fold_records, "train_acc": round(train_acc, 4),
|
| 332 |
+
}
|
| 333 |
+
session.training_results[model_name] = result
|
| 334 |
+
session.touch()
|
| 335 |
emit(loop, queue, {"type": "model_done", "model": model_name, **result})
|
| 336 |
|
| 337 |
except Exception as e:
|
|
|
|
| 339 |
traceback.print_exc()
|
| 340 |
|
| 341 |
|
| 342 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 343 |
+
# CNN training (Keras, mel spectrograms)
|
| 344 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 345 |
+
|
| 346 |
+
def _train_cnn(waveforms, y, groups, loop, queue, session, epochs=50):
|
| 347 |
+
model_name = "CNN"
|
| 348 |
+
try:
|
| 349 |
+
import tensorflow as tf
|
| 350 |
+
tf.random.set_seed(SEED)
|
| 351 |
+
|
| 352 |
+
# Build spectrogram tensor (N, 64, 128, 1)
|
| 353 |
+
emit(loop, queue, {"type": "task1_done", "model": model_name,
|
| 354 |
+
"acc": 0.0, "f1": 0.0, "message": "Extracting spectrograms..."})
|
| 355 |
+
X_spec = np.stack([
|
| 356 |
+
extract_mel_spectrogram(np.array(w, dtype=np.float32))
|
| 357 |
+
for w in waveforms
|
| 358 |
+
], axis=0)[..., np.newaxis] # (N, 64, 128, 1)
|
| 359 |
+
|
| 360 |
+
# Task 1
|
| 361 |
+
X_tr, X_te, y_tr, y_te = train_test_split(
|
| 362 |
+
X_spec, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
|
| 363 |
+
)
|
| 364 |
+
model = build_cnn()
|
| 365 |
+
cb = _make_keras_callback(loop, queue, model_name, epochs)
|
| 366 |
+
es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
|
| 367 |
+
model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
|
| 368 |
+
validation_split=0.15, callbacks=[cb, es], verbose=0)
|
| 369 |
+
|
| 370 |
+
y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
|
| 371 |
+
acc_t1 = float(accuracy_score(y_te, y_p1))
|
| 372 |
+
f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
|
| 373 |
+
cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
|
| 374 |
+
|
| 375 |
+
# Task 2 β LOIO
|
| 376 |
+
logo = LeaveOneGroupOut()
|
| 377 |
+
fold_accs, fold_records, all_yt, all_yp = [], [], [], []
|
| 378 |
+
|
| 379 |
+
for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_spec, y, groups)):
|
| 380 |
+
flange_out = int(groups[te_idx[0]])
|
| 381 |
+
m2 = build_cnn()
|
| 382 |
+
es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
|
| 383 |
+
m2.fit(X_spec[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
|
| 384 |
+
validation_split=0.15, callbacks=[es2], verbose=0)
|
| 385 |
+
yp = np.argmax(m2.predict(X_spec[te_idx], verbose=0), axis=1)
|
| 386 |
+
acc_f = float(accuracy_score(y[te_idx], yp))
|
| 387 |
+
fold_accs.append(acc_f)
|
| 388 |
+
fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
|
| 389 |
+
"acc": round(acc_f, 4), "n_test": len(te_idx)})
|
| 390 |
+
all_yt.extend(y[te_idx].tolist())
|
| 391 |
+
all_yp.extend(yp.tolist())
|
| 392 |
+
emit(loop, queue, {"type": "fold_done", "model": model_name,
|
| 393 |
+
"fold": fold_i + 1, "flange_out": flange_out,
|
| 394 |
+
"acc": round(acc_f, 4)})
|
| 395 |
+
|
| 396 |
+
cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
|
| 397 |
+
f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
|
| 398 |
+
train_acc = float(accuracy_score(y, np.argmax(model.predict(X_spec, verbose=0), axis=1)))
|
| 399 |
+
|
| 400 |
+
result = {
|
| 401 |
+
"model": model_name, "task1_acc": round(acc_t1, 4),
|
| 402 |
+
"task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
|
| 403 |
+
"task2_mean": round(float(np.mean(fold_accs)), 4),
|
| 404 |
+
"task2_std": round(float(np.std(fold_accs)), 4),
|
| 405 |
+
"task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
|
| 406 |
+
"folds": fold_records, "train_acc": round(train_acc, 4),
|
| 407 |
+
}
|
| 408 |
+
session.training_results[model_name] = result
|
| 409 |
+
session.touch()
|
| 410 |
+
emit(loop, queue, {"type": "model_done", "model": model_name, **result})
|
| 411 |
+
|
| 412 |
+
except Exception as e:
|
| 413 |
+
emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
|
| 414 |
+
traceback.print_exc()
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 418 |
+
# BiLSTM training (Keras, mel sequences)
|
| 419 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 420 |
+
|
| 421 |
+
def _train_bilstm(waveforms, y, groups, loop, queue, session, epochs=50):
|
| 422 |
+
model_name = "BiLSTM"
|
| 423 |
+
try:
|
| 424 |
+
import tensorflow as tf
|
| 425 |
+
tf.random.set_seed(SEED)
|
| 426 |
+
|
| 427 |
+
# Reshape to (N, 128, 64) β time steps Γ mel features
|
| 428 |
+
X_seq = np.stack([
|
| 429 |
+
extract_mel_spectrogram(np.array(w, dtype=np.float32)).T # (128, 64)
|
| 430 |
+
for w in waveforms
|
| 431 |
+
], axis=0)
|
| 432 |
+
|
| 433 |
+
# Task 1
|
| 434 |
+
X_tr, X_te, y_tr, y_te = train_test_split(
|
| 435 |
+
X_seq, y, test_size=TEST_SIZE, stratify=y, random_state=SEED
|
| 436 |
+
)
|
| 437 |
+
model = build_bilstm()
|
| 438 |
+
cb = _make_keras_callback(loop, queue, model_name, epochs)
|
| 439 |
+
es = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
|
| 440 |
+
model.fit(X_tr, y_tr, epochs=epochs, batch_size=32,
|
| 441 |
+
validation_split=0.15, callbacks=[cb, es], verbose=0)
|
| 442 |
+
|
| 443 |
+
y_p1 = np.argmax(model.predict(X_te, verbose=0), axis=1)
|
| 444 |
+
acc_t1 = float(accuracy_score(y_te, y_p1))
|
| 445 |
+
f1_t1 = float(f1_score(y_te, y_p1, average="macro"))
|
| 446 |
+
cm_t1 = confusion_matrix(y_te, y_p1, labels=[0, 1, 2]).tolist()
|
| 447 |
+
|
| 448 |
+
# Task 2 β LOIO
|
| 449 |
+
logo = LeaveOneGroupOut()
|
| 450 |
+
fold_accs, fold_records, all_yt, all_yp = [], [], [], []
|
| 451 |
+
|
| 452 |
+
for fold_i, (tr_idx, te_idx) in enumerate(logo.split(X_seq, y, groups)):
|
| 453 |
+
flange_out = int(groups[te_idx[0]])
|
| 454 |
+
m2 = build_bilstm()
|
| 455 |
+
es2 = tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
|
| 456 |
+
m2.fit(X_seq[tr_idx], y[tr_idx], epochs=epochs, batch_size=32,
|
| 457 |
+
validation_split=0.15, callbacks=[es2], verbose=0)
|
| 458 |
+
yp = np.argmax(m2.predict(X_seq[te_idx], verbose=0), axis=1)
|
| 459 |
+
acc_f = float(accuracy_score(y[te_idx], yp))
|
| 460 |
+
fold_accs.append(acc_f)
|
| 461 |
+
fold_records.append({"fold": fold_i + 1, "flange_out": flange_out,
|
| 462 |
+
"acc": round(acc_f, 4), "n_test": len(te_idx)})
|
| 463 |
+
all_yt.extend(y[te_idx].tolist())
|
| 464 |
+
all_yp.extend(yp.tolist())
|
| 465 |
+
emit(loop, queue, {"type": "fold_done", "model": model_name,
|
| 466 |
+
"fold": fold_i + 1, "flange_out": flange_out,
|
| 467 |
+
"acc": round(acc_f, 4)})
|
| 468 |
+
|
| 469 |
+
cm_t2 = confusion_matrix(all_yt, all_yp, labels=[0, 1, 2]).tolist()
|
| 470 |
+
f1_t2 = float(f1_score(all_yt, all_yp, average="macro"))
|
| 471 |
+
train_acc = float(accuracy_score(y, np.argmax(model.predict(X_seq, verbose=0), axis=1)))
|
| 472 |
+
|
| 473 |
+
result = {
|
| 474 |
+
"model": model_name, "task1_acc": round(acc_t1, 4),
|
| 475 |
+
"task1_f1": round(f1_t1, 4), "task1_cm": cm_t1,
|
| 476 |
+
"task2_mean": round(float(np.mean(fold_accs)), 4),
|
| 477 |
+
"task2_std": round(float(np.std(fold_accs)), 4),
|
| 478 |
+
"task2_f1": round(f1_t2, 4), "task2_cm": cm_t2,
|
| 479 |
+
"folds": fold_records, "train_acc": round(train_acc, 4),
|
| 480 |
+
}
|
| 481 |
+
session.training_results[model_name] = result
|
| 482 |
+
session.touch()
|
| 483 |
+
emit(loop, queue, {"type": "model_done", "model": model_name, **result})
|
| 484 |
+
|
| 485 |
+
except Exception as e:
|
| 486 |
+
emit(loop, queue, {"type": "error", "model": model_name, "message": str(e)})
|
| 487 |
+
traceback.print_exc()
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 491 |
+
# Master training thread
|
| 492 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 493 |
+
|
| 494 |
+
ALL_MODELS = ["SVM", "LR", "KNN", "MLP", "CNN", "BiLSTM"]
|
| 495 |
+
|
| 496 |
+
def _train_all(task_id: str, session_id: str, models: list[str]):
|
| 497 |
loop = asyncio.new_event_loop()
|
| 498 |
asyncio.set_event_loop(loop)
|
| 499 |
|
| 500 |
session = session_manager.get(session_id)
|
| 501 |
if session is None:
|
| 502 |
return
|
|
|
|
| 503 |
queue = ws_manager.get_queue(task_id)
|
| 504 |
if queue is None:
|
| 505 |
return
|
|
|
|
| 509 |
emit(loop, queue, {"type": "error", "message": "Features not extracted yet"})
|
| 510 |
return
|
| 511 |
|
| 512 |
+
X = np.array(feats["X_feat"], dtype=np.float32)
|
| 513 |
+
y = np.array(feats["labels"], dtype=np.int64)
|
| 514 |
+
groups = np.array(feats["flange_groups"], dtype=np.int64)
|
| 515 |
|
| 516 |
+
# NaN imputation (tau column can have NaNs)
|
| 517 |
+
X = impute_nans(X, y)
|
| 518 |
+
|
| 519 |
+
waveforms = session.hits.get("waveforms", [])
|
| 520 |
+
|
| 521 |
+
for m in models:
|
| 522 |
+
if m in SHALLOW_FACTORIES:
|
| 523 |
+
_train_shallow(m, X, y, groups, loop, queue, session)
|
| 524 |
+
elif m == "MLP":
|
| 525 |
+
_train_mlp(X, y, groups, loop, queue, session)
|
| 526 |
+
elif m == "CNN":
|
| 527 |
+
_train_cnn(waveforms, y, groups, loop, queue, session)
|
| 528 |
+
elif m == "BiLSTM":
|
| 529 |
+
_train_bilstm(waveforms, y, groups, loop, queue, session)
|
| 530 |
|
| 531 |
emit(loop, queue, {"type": "all_done", "task_id": task_id})
|
| 532 |
|
| 533 |
|
| 534 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 535 |
+
# Routes
|
| 536 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 537 |
|
| 538 |
@router.post("/api/train")
|
| 539 |
async def start_training(
|
| 540 |
session_id: str = Header(..., alias="X-Session-Id"),
|
| 541 |
body: dict = None,
|
| 542 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
session = session_manager.get(session_id)
|
| 544 |
if session is None:
|
| 545 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 546 |
if not session.features:
|
| 547 |
raise HTTPException(status_code=400, detail="Extract features first: POST /api/features")
|
| 548 |
|
| 549 |
+
models = (body or {}).get("models", ALL_MODELS)
|
|
|
|
| 550 |
task_id = str(uuid.uuid4())
|
|
|
|
| 551 |
ws_manager.create_queue(task_id)
|
| 552 |
session.training_tasks[task_id] = models
|
| 553 |
session.touch()
|
| 554 |
|
| 555 |
loop = asyncio.get_event_loop()
|
| 556 |
+
loop.run_in_executor(_executor, _train_all, task_id, session_id, models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
|
| 558 |
return {"task_id": task_id, "models": models}
|
| 559 |
|
| 560 |
|
| 561 |
@router.get("/api/results")
|
| 562 |
async def get_results(session_id: str = Header(..., alias="X-Session-Id")):
|
|
|
|
| 563 |
session = session_manager.get(session_id)
|
| 564 |
if session is None:
|
| 565 |
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
| 569 |
}
|
| 570 |
|
| 571 |
|
|
|
|
|
|
|
| 572 |
@router.websocket("/ws/train/{task_id}")
|
| 573 |
async def training_websocket(websocket: WebSocket, task_id: str):
|
|
|
|
| 574 |
await ws_manager.connect(task_id, websocket)
|
| 575 |
try:
|
| 576 |
await ws_manager.stream(task_id, websocket)
|