File size: 6,275 Bytes
2c11783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Router: /api/process
Runs the full hit detection pipeline across all uploaded training files.
Returns extracted hit metadata and waveform previews for the frontend.
"""

import numpy as np
from fastapi import APIRouter, Header, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse

from session import session_manager
from utils.audio import extract_hits_from_file, load_audio, get_rms_envelope
from config import SR, HIT_WINDOW_LEN

# Router for the processing endpoints; every route declared on it is served
# under the /api prefix and tagged "process".
router = APIRouter(prefix="/api", tags=["process"])

# Downsample factor for waveform preview (send 1 in N samples to frontend)
# NOTE(review): the 750 Hz figure below assumes config.SR is 48000 Hz — confirm.
PREVIEW_DOWNSAMPLE = 64    # 48000 Hz → 750 Hz preview resolution


def downsample(arr: np.ndarray, factor: int) -> list[float]:
    """Keep every ``factor``-th sample of ``arr`` and return it as a plain list.

    The first sample is always kept; no filtering or averaging is applied,
    the array is simply strided.
    """
    every_nth = slice(None, None, factor)
    strided = arr[every_nth]
    return strided.tolist()


@router.post("/process")
async def process_files(session_id: str = Header(..., alias="X-Session-Id")):
    """
    Extract hits from all uploaded training files.

    For every entry in ``session.uploaded_files`` this calls
    ``extract_hits_from_file`` and collects the kept hit windows together with
    their class, flange and area ids. Results are stored on the session as
    ``session.hits`` (plain lists) and ``session.processing_stats``.

    Args:
        session_id: Value of the ``X-Session-Id`` request header.

    Returns:
        A dict with the processing status, the total number of kept hits, the
        stats dict (per-file counts, class and flange distributions, and the
        first 200 quality-log entries), a downsampled preview of the first
        kept hit, and downsampled waveform / RMS previews of the first
        uploaded file. The previews are empty lists when unavailable.

    Raises:
        HTTPException: 404 if the session is unknown or expired, 400 if no
            files have been uploaded to the session.
    """
    # Function-scope imports: logging is only needed on the preview failure
    # path, and IDX_TO_CLASS was already imported locally by this handler.
    import logging

    from config import IDX_TO_CLASS

    session = session_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    if not session.uploaded_files:
        raise HTTPException(status_code=400, detail="No files uploaded. Call POST /api/upload first.")

    all_waveforms:     list[list[float]] = []
    all_labels:        list[int]  = []
    all_flange_groups: list[int]  = []
    all_area_groups:   list[int]  = []
    per_file_stats:    list[dict] = []
    all_quality_logs:  list[dict] = []

    for finfo in session.uploaded_files:
        windows, quality_log = extract_hits_from_file(
            filepath=finfo["filepath"],
            class_idx=finfo["class_idx"],
            flange_id=finfo["flange_id"],
            area_id=finfo["area_id"],
        )
        # The quality log has one entry per detected hit; windows holds only
        # the hits that were kept, so the difference is the rejected count.
        kept     = len(windows)
        detected = len(quality_log)
        rejected = detected - kept

        # One label / flange / area entry per kept window, in window order.
        all_waveforms.extend(win.tolist() for win in windows)
        all_labels.extend([finfo["class_idx"]] * kept)
        all_flange_groups.extend([finfo["flange_id"]] * kept)
        all_area_groups.extend([finfo["area_id"]] * kept)

        # Tag each log entry with its originating file so the merged log
        # remains attributable.
        for entry in quality_log:
            entry["filename"]    = finfo["filename"]
            entry["flange_id"]   = finfo["flange_id"]
            entry["class_label"] = finfo["class_label"]
            entry["area_id"]     = finfo["area_id"]
        all_quality_logs.extend(quality_log)

        per_file_stats.append({
            "filename":    finfo["filename"],
            "flange_id":   finfo["flange_id"],
            "class_label": finfo["class_label"],
            "area_id":     finfo["area_id"],
            "detected":    detected,
            "kept":        kept,
            "rejected":    rejected,
        })

    # Store in session (as lists — avoid numpy serialisation issues)
    session.hits = {
        "waveforms":     all_waveforms,
        "labels":        all_labels,
        "flange_groups": all_flange_groups,
        "area_groups":   all_area_groups,
        "n_hits":        len(all_waveforms),
        "hit_window_len": HIT_WINDOW_LEN,
        "sr":            SR,
    }

    # Class distribution.
    # NOTE(review): only class indices 0-2 are counted; hits with any other
    # class_idx are left out of class_dist — confirm this matches config.
    labels_arr = np.array(all_labels)
    class_dist = {
        str(IDX_TO_CLASS[idx]): int((labels_arr == idx).sum())
        for idx in [0, 1, 2]
    }

    # Flange distribution.
    # NOTE(review): only flange ids 1-4 are counted; hits from any other
    # flange_id are left out of flange_dist.
    flanges_arr = np.array(all_flange_groups)
    flange_dist = {
        str(fl): int((flanges_arr == fl).sum())
        for fl in [1, 2, 3, 4]
    }

    # Waveform preview: first kept hit (downsampled for network efficiency)
    preview_waveform: list[float] = []
    if all_waveforms:
        preview_waveform = downsample(np.array(all_waveforms[0]), PREVIEW_DOWNSAMPLE)

    # RMS envelope and waveform of the full first file (for waveform page
    # visualisation). uploaded_files is known to be non-empty here because of
    # the 400 guard above.
    # NOTE(review): the RMS preview is downsampled by 4, which the response
    # does not report; "downsample_factor" only describes the waveforms.
    first_file_rms_preview: list[float] = []
    first_file_waveform_preview: list[float] = []
    first_path = session.uploaded_files[0]["filepath"]
    try:
        y0 = load_audio(first_path)
        rms0, _ = get_rms_envelope(y0)
        first_file_rms_preview = downsample(rms0, 4)
        first_file_waveform_preview = downsample(y0, PREVIEW_DOWNSAMPLE)
    except Exception:
        # The preview is best-effort: a failure must not fail the request,
        # but it is logged rather than silently discarded.
        logging.getLogger(__name__).warning(
            "Could not build first-file preview for %s", first_path, exc_info=True
        )

    stats = {
        "n_files":        len(session.uploaded_files),
        "n_hits_total":   len(all_waveforms),
        "n_hits_rejected": sum(s["rejected"] for s in per_file_stats),
        "class_dist":     class_dist,
        "flange_dist":    flange_dist,
        "per_file":       per_file_stats,
        "quality_log":    all_quality_logs[:200],   # cap to avoid huge response
    }
    session.processing_stats = stats
    session.touch()

    return {
        "status":                    "done",
        "n_hits":                    len(all_waveforms),
        "stats":                     stats,
        "preview_hit_waveform":      preview_waveform,
        "first_file_waveform":       first_file_waveform_preview,
        "first_file_rms":            first_file_rms_preview,
        "downsample_factor":         PREVIEW_DOWNSAMPLE,
        "preview_sr_hz":             SR // PREVIEW_DOWNSAMPLE,
    }


@router.get("/process/hit/{hit_idx}")
async def get_hit_waveform(
    hit_idx: int,
    session_id: str = Header(..., alias="X-Session-Id"),
):
    """
    Return the waveform for a specific hit index (downsampled).

    Args:
        hit_idx: Zero-based index into the hits stored on the session.
        session_id: Value of the ``X-Session-Id`` request header.

    Returns:
        A dict with the hit index, its label, flange id and area id, the
        waveform downsampled by ``PREVIEW_DOWNSAMPLE``, and the length of the
        full-resolution waveform.

    Raises:
        HTTPException: 404 if the session is unknown, if no hits are stored
            on it, or if ``hit_idx`` is outside ``0 .. n_hits - 1``.
    """
    session = session_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    hits = session.hits
    # Negative indices must be rejected explicitly: Python list indexing
    # would otherwise wrap around (-1 -> last hit) or raise IndexError for
    # values below -n_hits, which would surface as a 500.
    if not hits or not 0 <= hit_idx < hits["n_hits"]:
        raise HTTPException(status_code=404, detail=f"Hit {hit_idx} not found")

    full_waveform = hits["waveforms"][hit_idx]
    return {
        "hit_idx":    hit_idx,
        "label":      hits["labels"][hit_idx],
        "flange_id":  hits["flange_groups"][hit_idx],
        "area_id":    hits["area_groups"][hit_idx],
        "waveform":   downsample(np.array(full_waveform), PREVIEW_DOWNSAMPLE),
        "waveform_full_len": len(full_waveform),
    }