Spaces:
Running
Running
| """ | |
| Router: /api/process | |
| Runs the full hit detection pipeline across all uploaded training files. | |
| Returns extracted hit metadata and waveform previews for the frontend. | |
| """ | |
| import numpy as np | |
| from fastapi import APIRouter, Header, HTTPException, BackgroundTasks | |
| from fastapi.responses import JSONResponse | |
| from session import session_manager | |
| from utils.audio import extract_hits_from_file, load_audio, get_rms_envelope | |
| from config import SR, HIT_WINDOW_LEN | |
| router = APIRouter(prefix="/api", tags=["process"]) | |
| # Downsample factor for waveform preview (send 1 in N samples to frontend) | |
| PREVIEW_DOWNSAMPLE = 64 # 48000 Hz → 750 Hz preview resolution | |
| def downsample(arr: np.ndarray, factor: int) -> list[float]: | |
| """Return a downsampled version of the array as a Python list.""" | |
| return arr[::factor].tolist() | |
| async def process_files(session_id: str = Header(..., alias="X-Session-Id")): | |
| """ | |
| Extract hits from all uploaded training files. | |
| Stores results in session.hits and session.processing_stats. | |
| Returns per-file stats, class distribution, and waveform preview of first hit. | |
| """ | |
| session = session_manager.get(session_id) | |
| if session is None: | |
| raise HTTPException(status_code=404, detail="Session not found or expired") | |
| if not session.uploaded_files: | |
| raise HTTPException(status_code=400, detail="No files uploaded. Call POST /api/upload first.") | |
| all_waveforms: list[list[float]] = [] | |
| all_labels: list[int] = [] | |
| all_flange_groups: list[int] = [] | |
| all_area_groups: list[int] = [] | |
| per_file_stats: list[dict] = [] | |
| all_quality_logs: list[dict] = [] | |
| for finfo in session.uploaded_files: | |
| windows, quality_log = extract_hits_from_file( | |
| filepath=finfo["filepath"], | |
| class_idx=finfo["class_idx"], | |
| flange_id=finfo["flange_id"], | |
| area_id=finfo["area_id"], | |
| ) | |
| kept = len(windows) | |
| detected = len(quality_log) | |
| rejected = detected - kept | |
| for win in windows: | |
| all_waveforms.append(win.tolist()) | |
| all_labels.append(finfo["class_idx"]) | |
| all_flange_groups.append(finfo["flange_id"]) | |
| all_area_groups.append(finfo["area_id"]) | |
| for entry in quality_log: | |
| entry["filename"] = finfo["filename"] | |
| entry["flange_id"] = finfo["flange_id"] | |
| entry["class_label"] = finfo["class_label"] | |
| entry["area_id"] = finfo["area_id"] | |
| all_quality_logs.extend(quality_log) | |
| per_file_stats.append({ | |
| "filename": finfo["filename"], | |
| "flange_id": finfo["flange_id"], | |
| "class_label": finfo["class_label"], | |
| "area_id": finfo["area_id"], | |
| "detected": detected, | |
| "kept": kept, | |
| "rejected": rejected, | |
| }) | |
| # Store in session (as lists — avoid numpy serialisation issues) | |
| session.hits = { | |
| "waveforms": all_waveforms, | |
| "labels": all_labels, | |
| "flange_groups": all_flange_groups, | |
| "area_groups": all_area_groups, | |
| "n_hits": len(all_waveforms), | |
| "hit_window_len": HIT_WINDOW_LEN, | |
| "sr": SR, | |
| } | |
| # Class distribution | |
| labels_arr = np.array(all_labels) | |
| from config import IDX_TO_CLASS | |
| class_dist = { | |
| str(IDX_TO_CLASS[idx]): int((labels_arr == idx).sum()) | |
| for idx in [0, 1, 2] | |
| } | |
| # Flange distribution | |
| flanges_arr = np.array(all_flange_groups) | |
| flange_dist = { | |
| str(fl): int((flanges_arr == fl).sum()) | |
| for fl in [1, 2, 3, 4] | |
| } | |
| # Waveform preview: first kept hit (downsampled for network efficiency) | |
| preview_waveform: list[float] = [] | |
| preview_rms: list[float] = [] | |
| if all_waveforms: | |
| win0 = np.array(all_waveforms[0]) | |
| preview_waveform = downsample(win0, PREVIEW_DOWNSAMPLE) | |
| # RMS envelope of full first file (for waveform page visualisation) | |
| first_file_rms_preview: list[float] = [] | |
| first_file_waveform_preview: list[float] = [] | |
| if session.uploaded_files: | |
| try: | |
| y0 = load_audio(session.uploaded_files[0]["filepath"]) | |
| rms0, _ = get_rms_envelope(y0) | |
| first_file_rms_preview = downsample(rms0, 4) | |
| first_file_waveform_preview = downsample(y0, PREVIEW_DOWNSAMPLE) | |
| except Exception: | |
| pass | |
| stats = { | |
| "n_files": len(session.uploaded_files), | |
| "n_hits_total": len(all_waveforms), | |
| "n_hits_rejected": sum(s["rejected"] for s in per_file_stats), | |
| "class_dist": class_dist, | |
| "flange_dist": flange_dist, | |
| "per_file": per_file_stats, | |
| "quality_log": all_quality_logs[:200], # cap to avoid huge response | |
| } | |
| session.processing_stats = stats | |
| session.touch() | |
| return { | |
| "status": "done", | |
| "n_hits": len(all_waveforms), | |
| "stats": stats, | |
| "preview_hit_waveform": preview_waveform, | |
| "first_file_waveform": first_file_waveform_preview, | |
| "first_file_rms": first_file_rms_preview, | |
| "downsample_factor": PREVIEW_DOWNSAMPLE, | |
| "preview_sr_hz": SR // PREVIEW_DOWNSAMPLE, | |
| } | |
| async def get_hit_waveform( | |
| hit_idx: int, | |
| session_id: str = Header(..., alias="X-Session-Id"), | |
| ): | |
| """Return the waveform for a specific hit index (downsampled).""" | |
| session = session_manager.get(session_id) | |
| if session is None: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| hits = session.hits | |
| if not hits or hit_idx >= hits["n_hits"]: | |
| raise HTTPException(status_code=404, detail=f"Hit {hit_idx} not found") | |
| win = np.array(hits["waveforms"][hit_idx]) | |
| return { | |
| "hit_idx": hit_idx, | |
| "label": hits["labels"][hit_idx], | |
| "flange_id": hits["flange_groups"][hit_idx], | |
| "area_id": hits["area_groups"][hit_idx], | |
| "waveform": downsample(win, PREVIEW_DOWNSAMPLE), | |
| "waveform_full_len": len(hits["waveforms"][hit_idx]), | |
| } | |