File size: 5,776 Bytes
2c11783
 
 
 
 
 
 
 
 
 
 
 
f4a70f9
 
2c11783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
Router: /api/features
Extracts 82-dimensional feature vectors from all hits and computes
mel spectrograms. Returns feature data + 2D PCA for visualisation.
"""

import numpy as np
from fastapi import APIRouter, Header, HTTPException

from session import session_manager
from ml.feature_extraction import (
    extract_features, extract_mel_spectrogram,
    relative_psd_log_bins, mfcc_stats, decay_tau,
    energy_ratio, impute_nans, FEATURE_NAMES
)
from config import IDX_TO_CLASS, CLASS_NAMES

router = APIRouter(prefix="/api", tags=["features"])


def pca_2d(X: np.ndarray) -> tuple[np.ndarray, np.ndarray, list[float]]:
    """Project *X* onto its two leading principal components.

    A dependency-light PCA (eigen-decomposition of the covariance matrix)
    intended for 2-D scatter visualisation.

    Args:
        X: 2-D array with one row per sample and one column per feature.

    Returns:
        A ``(X_pca, components, var_ratio)`` tuple where

        * ``X_pca`` has shape ``(n_samples, 2)`` — the centred data projected
          onto the two components,
        * ``components`` has shape ``(n_features, 2)`` — the projection axes
          as columns, ordered by decreasing eigenvalue,
        * ``var_ratio`` is a two-element list with each component's share of
          the total variance.

        Degenerate inputs never produce NaN: with fewer than two samples the
        covariance is undefined, so all outputs are zero; with zero total
        variance the ratios are ``0.0``; with a single feature column the
        missing second component is zero-padded.

    Raises:
        ValueError: if *X* is not two-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"pca_2d expects a 2-D array, got {X.ndim}-D")

    n_samples, n_features = X.shape
    if n_samples < 2 or n_features == 0:
        # np.cov divides by (n_samples - 1); a single sample would yield NaN.
        return np.zeros((n_samples, 2)), np.zeros((n_features, 2)), [0.0, 0.0]

    X_c = X - X.mean(axis=0)
    # np.cov collapses a single-feature input to a 0-d array, which eigh
    # rejects; atleast_2d restores the (1, 1) matrix.
    cov = np.atleast_2d(np.cov(X_c.T))
    eigvals, eigvecs = np.linalg.eigh(cov)

    # Indices of the (up to) two largest eigenvalues, largest first.
    order = np.argsort(eigvals)[::-1][:2]
    components = eigvecs[:, order]
    X_pca = X_c @ components

    total = eigvals.sum()
    if total > 0:
        var_ratio = eigvals[order] / total
    else:
        # Constant data: avoid 0/0 -> NaN, which is not valid JSON.
        var_ratio = np.zeros(len(order))

    # Callers index columns 0 and 1 unconditionally, so always return two.
    missing = 2 - components.shape[1]
    if missing > 0:
        components = np.pad(components, ((0, 0), (0, missing)))
        X_pca = np.pad(X_pca, ((0, 0), (0, missing)))
        var_ratio = np.pad(var_ratio, (0, missing))

    return X_pca, components, var_ratio.tolist()


@router.post("/features")
async def extract_all_features(session_id: str = Header(..., alias="X-Session-Id")):
    """
    Extract a feature vector for every hit held in the session.

    The session is looked up from the ``X-Session-Id`` header. The resulting
    feature matrix, labels, flange groups, 2-D PCA coordinates and per-class
    mean profiles are saved on ``session.features``; the response carries a
    summary, the PCA scatter points and per-feature statistics.

    Raises:
        HTTPException 404: the session does not exist (or has expired).
        HTTPException 400: the session holds no hits yet.
    """
    session = session_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    hits = session.hits
    if not hits or hits["n_hits"] == 0:
        raise HTTPException(status_code=400, detail="No hits found. Run /api/process first.")

    waveforms = hits["waveforms"]
    labels = np.array(hits["labels"])
    flange_groups = np.array(hits["flange_groups"])

    # One feature row per hit -> (n_hits, n_features).
    feature_rows = [
        extract_features(np.array(waveform, dtype=np.float32))
        for waveform in waveforms
    ]
    feature_matrix = np.stack(feature_rows, axis=0)

    # 2-D projection used by the frontend scatter plot.
    projected, _components, explained = pca_2d(feature_matrix)

    # Mean feature vector of each class that is actually present.
    class_profiles = {}
    for class_idx in (0, 1, 2):
        in_class = labels == class_idx
        if not in_class.any():
            continue
        profile = feature_matrix[in_class].mean(axis=0)
        class_profiles[str(IDX_TO_CLASS[class_idx])] = profile.tolist()

    n_hits, n_features = feature_matrix.shape[0], feature_matrix.shape[1]

    # Persist everything on the session and refresh it via touch().
    session.features = {
        "X_feat":         feature_matrix.tolist(),
        "feature_names":  FEATURE_NAMES,
        "labels":         labels.tolist(),
        "flange_groups":  flange_groups.tolist(),
        "X_pca":          projected.tolist(),
        "pca_var_ratio":  explained,
        "class_profiles": class_profiles,
        "n_features":     n_features,
        "n_hits":         n_hits,
    }
    session.touch()

    # One scatter point per label: PCA coordinates plus class / flange info.
    scatter = []
    for row, label in enumerate(labels):
        scatter.append({
            "x":          round(float(projected[row, 0]), 4),
            "y":          round(float(projected[row, 1]), 4),
            "label_idx":  int(label),
            "label_name": CLASS_NAMES[int(label)],
            "flange":     int(flange_groups[row]),
        })

    def _column_summary(column):
        """Return rounded mean / std / min / max of one feature column."""
        return {
            "mean": round(float(column.mean()), 4),
            "std":  round(float(column.std()),  4),
            "min":  round(float(column.min()),  4),
            "max":  round(float(column.max()),  4),
        }

    feature_stats = {
        name: _column_summary(feature_matrix[:, col])
        for col, name in enumerate(FEATURE_NAMES)
    }

    return {
        "status":        "done",
        "n_hits":        int(n_hits),
        "n_features":    int(n_features),
        "feature_names": FEATURE_NAMES,
        "scatter":       scatter,
        "pca_var_ratio": explained,
        "class_profiles": class_profiles,
        "feature_stats": feature_stats,
    }


@router.get("/features/hit/{hit_idx}")
async def get_hit_features(
    hit_idx: int,
    session_id: str = Header(..., alias="X-Session-Id"),
):
    """
    Return all feature visualisation data for a single hit:
    waveform-derived PSD, MFCC statistics, decay constant, energy ratio,
    mel spectrogram, an RMS decay envelope and the full feature vector.
    Used by the Feature Extraction screen's hit picker.

    Args:
        hit_idx: Zero-based index of the hit within the session's hits.
        session_id: Session identifier taken from the ``X-Session-Id`` header.

    Raises:
        HTTPException 404: the session does not exist, the session has no
            hits, or ``hit_idx`` is outside ``[0, n_hits)``.
    """
    session = session_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    hits = session.hits
    # Reject negative indices too: Python would otherwise silently wrap
    # around and return a hit counted from the end of the list.
    if not hits or not 0 <= hit_idx < hits["n_hits"]:
        raise HTTPException(status_code=404, detail=f"Hit {hit_idx} not found")

    win = np.array(hits["waveforms"][hit_idx], dtype=np.float32)

    # Individual feature components for visualisation.
    # These use the helpers this module actually imports from
    # ml.feature_extraction (the previous extract_psd / extract_mfcc /
    # extract_decay / extract_energy_ratio names were never defined).
    # NOTE(review): the single-argument call shape and the (mean, std) pair
    # returned by mfcc_stats are carried over from the original call sites —
    # confirm against ml.feature_extraction.
    psd            = relative_psd_log_bins(win)
    mfcc_m, mfcc_s = mfcc_stats(win)
    tau            = decay_tau(win)
    # Named e_ratio so the local does not shadow the imported energy_ratio().
    e_ratio        = energy_ratio(win)
    mel            = extract_mel_spectrogram(win)

    # Decay envelope for chart: RMS over consecutive 5 ms frames.
    # NOTE(review): 48000 is presumably the sample rate in Hz — confirm it
    # matches the rate the waveforms were captured at.
    frame_len = int(0.005 * 48000)
    n_frames  = len(win) // frame_len
    rms_env   = [
        float(np.sqrt(np.mean(win[i * frame_len:(i + 1) * frame_len] ** 2)))
        for i in range(n_frames)
    ]
    decay_t = [round(i * 0.005, 4) for i in range(n_frames)]

    # Full feature vector
    feat = extract_features(win)

    # Cast to a plain int, as extract_all_features does, so the value is
    # usable both as a CLASS_NAMES index and in the JSON response.
    label = int(hits["labels"][hit_idx])

    return {
        "hit_idx":     hit_idx,
        "label_idx":   label,
        "label_name":  CLASS_NAMES[label],
        "flange_id":   hits["flange_groups"][hit_idx],
        "psd":         psd.tolist(),
        "mfcc_mean":   mfcc_m.tolist(),
        "mfcc_std":    mfcc_s.tolist(),
        "tau":         round(float(tau), 4),
        "energy_ratio": round(float(e_ratio), 4),
        "mel_spectrogram": mel.tolist(),   # 2-D array rendered as a Plotly heatmap
        "decay_rms":   rms_env,
        "decay_t":     decay_t,
        "feature_vector": feat.tolist(),
        "feature_names":  FEATURE_NAMES,
    }