Spaces:
Sleeping
Sleeping
File size: 5,776 Bytes
2c11783 f4a70f9 2c11783 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | """
Router: /api/features
Extracts 82-dimensional feature vectors from all hits and computes
mel spectrograms. Returns feature data + 2D PCA for visualisation.
"""
import numpy as np
from fastapi import APIRouter, Header, HTTPException
from session import session_manager
from ml.feature_extraction import (
extract_features, extract_mel_spectrogram,
relative_psd_log_bins, mfcc_stats, decay_tau,
energy_ratio, impute_nans, FEATURE_NAMES
)
from config import IDX_TO_CLASS, CLASS_NAMES
# Router for the feature endpoints; every route declared on it is served
# under the "/api" prefix and tagged "features".
router = APIRouter(prefix="/api", tags=["features"])
def pca_2d(X: np.ndarray) -> tuple[np.ndarray, np.ndarray, list[float]]:
    """Project the rows of ``X`` onto their first two principal components.

    The projection is obtained from an eigendecomposition of the sample
    covariance matrix, so no scikit-learn dependency is needed for the
    basic visualisation.

    Degenerate inputs never produce NaN (which cannot be serialised to
    JSON): fewer than two samples, or zero total variance, yield zero
    coordinates / zero variance ratios, and a single feature column is
    padded with a zero second component.

    Args:
        X: 2-D array with one row per sample and one column per feature.

    Returns:
        A ``(X_pca, components, var_ratio)`` tuple. ``X_pca`` has shape
        ``(n_samples, 2)``, ``components`` has shape ``(n_features, 2)``
        and ``var_ratio`` is a two-element list holding the fraction of
        the total variance captured by each component.

    Raises:
        ValueError: If ``X`` is not two-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"pca_2d expects a 2-D array, got {X.ndim}-D input")

    n_samples, n_features = X.shape
    # The sample covariance is undefined for fewer than two observations
    # (division by n - 1 == 0), so short-circuit instead of returning NaN.
    if n_samples < 2 or n_features == 0:
        return (
            np.zeros((n_samples, 2)),
            np.zeros((n_features, 2)),
            [0.0, 0.0],
        )

    X_c = X - X.mean(axis=0)
    # np.cov collapses to a 0-d array for a single feature; eigh needs 2-D.
    cov = np.atleast_2d(np.cov(X_c.T))
    eigvals, eigvecs = np.linalg.eigh(cov)

    # eigh returns eigenvalues in ascending order; keep the two largest.
    top = np.argsort(eigvals)[::-1][:2]
    components = eigvecs[:, top]
    X_pca = X_c @ components

    total = eigvals.sum()
    if np.isfinite(total) and total > 0:
        var_ratio = eigvals[top] / total
    else:
        # All samples identical: no variance to attribute to any component.
        var_ratio = np.zeros(len(top))

    # Callers index both columns, so always hand back two components.
    missing = 2 - components.shape[1]
    if missing > 0:
        components = np.pad(components, ((0, 0), (0, missing)))
        X_pca = np.pad(X_pca, ((0, 0), (0, missing)))
        var_ratio = np.pad(var_ratio, (0, missing))

    return X_pca, components, var_ratio.tolist()
@router.post("/features")
async def extract_all_features(session_id: str = Header(..., alias="X-Session-Id")):
    """
    Build the feature matrix for every hit held by the session.

    Each stored waveform is converted to float32 and passed through
    ``extract_features``; the stacked result is projected to 2-D with
    ``pca_2d`` and averaged per class. The matrix, labels, flange groups,
    PCA coordinates and class profiles are saved on ``session.features``.

    Responds with 404 when the session is unknown and 400 when the session
    holds no hits. On success returns the hit/feature counts, feature
    names, PCA scatter points, explained-variance ratios, per-class mean
    profiles and per-feature mean/std/min/max.
    """
    sess = session_manager.get(session_id)
    if sess is None:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    stored = sess.hits
    if not stored or stored["n_hits"] == 0:
        raise HTTPException(status_code=400, detail="No hits found. Run /api/process first.")

    raw_waveforms = stored["waveforms"]
    labels = np.array(stored["labels"])
    flange_groups = np.array(stored["flange_groups"])

    # One feature row per hit, stacked into a single matrix.
    feature_rows = [
        extract_features(np.array(wave, dtype=np.float32))
        for wave in raw_waveforms
    ]
    X_feat = np.stack(feature_rows, axis=0)

    # 2-D projection used by the scatter plot (component vectors unused here).
    coords, _components, var_ratio = pca_2d(X_feat)

    # Mean feature vector of each class that actually occurs in the labels.
    class_profiles = {
        str(IDX_TO_CLASS[class_idx]): X_feat[labels == class_idx].mean(axis=0).tolist()
        for class_idx in (0, 1, 2)
        if (labels == class_idx).any()
    }

    # Persist the results on the session and refresh it.
    sess.features = {
        "X_feat": X_feat.tolist(),
        "feature_names": FEATURE_NAMES,
        "labels": labels.tolist(),
        "flange_groups": flange_groups.tolist(),
        "X_pca": coords.tolist(),
        "pca_var_ratio": var_ratio,
        "class_profiles": class_profiles,
        "n_features": X_feat.shape[1],
        "n_hits": X_feat.shape[0],
    }
    sess.touch()

    # One scatter point per hit: PCA coordinates plus label and flange.
    scatter = []
    for i, raw_label in enumerate(labels):
        label_idx = int(raw_label)
        scatter.append(
            {
                "x": round(float(coords[i, 0]), 4),
                "y": round(float(coords[i, 1]), 4),
                "label_idx": label_idx,
                "label_name": CLASS_NAMES[label_idx],
                "flange": int(flange_groups[i]),
            }
        )

    # Summary statistics of each feature column, rounded for display.
    feature_stats = {}
    for col_idx, feat_name in enumerate(FEATURE_NAMES):
        column = X_feat[:, col_idx]
        feature_stats[feat_name] = {
            "mean": round(float(column.mean()), 4),
            "std": round(float(column.std()), 4),
            "min": round(float(column.min()), 4),
            "max": round(float(column.max()), 4),
        }

    return {
        "status": "done",
        "n_hits": int(X_feat.shape[0]),
        "n_features": int(X_feat.shape[1]),
        "feature_names": FEATURE_NAMES,
        "scatter": scatter,
        "pca_var_ratio": var_ratio,
        "class_profiles": class_profiles,
        "feature_stats": feature_stats,
    }
@router.get("/features/hit/{hit_idx}")
async def get_hit_features(
    hit_idx: int,
    session_id: str = Header(..., alias="X-Session-Id"),
):
    """
    Return all feature visualisation data for a single hit:
    waveform-derived PSD, MFCC statistics, decay constant, energy ratio,
    mel spectrogram, a framed RMS decay envelope and the full feature vector.
    Used by the Feature Extraction screen's hit picker.

    Responds with 404 when the session is unknown or when ``hit_idx`` is
    outside ``0 .. n_hits - 1``.
    """
    session = session_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    hits = session.hits
    # Negative indices must be rejected as well: Python list indexing would
    # otherwise wrap around and silently return a different hit.
    if not hits or not 0 <= hit_idx < hits["n_hits"]:
        raise HTTPException(status_code=404, detail=f"Hit {hit_idx} not found")

    win = np.array(hits["waveforms"][hit_idx], dtype=np.float32)

    # Individual feature components for visualisation. These use the helpers
    # this module actually imports; the previous extract_psd / extract_mfcc /
    # extract_decay / extract_energy_ratio names were never defined and
    # raised NameError.
    # NOTE(review): assumes each helper accepts the waveform alone and that
    # mfcc_stats returns a (mean, std) pair — confirm against
    # ml.feature_extraction.
    psd = relative_psd_log_bins(win)
    mfcc_m, mfcc_s = mfcc_stats(win)
    tau = decay_tau(win)
    # Kept under a distinct local name so the imported energy_ratio function
    # is not shadowed inside this function.
    hit_energy_ratio = energy_ratio(win)
    mel = extract_mel_spectrogram(win)

    # Decay envelope for chart: RMS over consecutive 5 ms frames.
    # NOTE(review): 48000 is presumably the sample rate in Hz — confirm.
    frame_len = int(0.005 * 48000)
    n_frames = len(win) // frame_len
    rms_env = [
        float(np.sqrt(np.mean(win[i * frame_len:(i + 1) * frame_len] ** 2)))
        for i in range(n_frames)
    ]
    decay_t = [round(i * 0.005, 4) for i in range(n_frames)]

    # Full feature vector
    feat = extract_features(win)

    label_idx = hits["labels"][hit_idx]
    return {
        "hit_idx": hit_idx,
        "label_idx": label_idx,
        "label_name": CLASS_NAMES[label_idx],
        "flange_id": hits["flange_groups"][hit_idx],
        "psd": psd.tolist(),
        "mfcc_mean": mfcc_m.tolist(),
        "mfcc_std": mfcc_s.tolist(),
        "tau": round(float(tau), 4),
        "energy_ratio": round(float(hit_energy_ratio), 4),
        "mel_spectrogram": mel.tolist(),  # rendered as a Plotly heatmap
        "decay_rms": rms_env,
        "decay_t": decay_t,
        "feature_vector": feat.tolist(),
        "feature_names": FEATURE_NAMES,
    }
|