flange-ml-api / routers /process.py
bbmb's picture
Initial backend deployment
2c11783
"""
Router: /api/process
Runs the full hit detection pipeline across all uploaded training files.
Returns extracted hit metadata and waveform previews for the frontend.
"""
import numpy as np
from fastapi import APIRouter, Header, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from session import session_manager
from utils.audio import extract_hits_from_file, load_audio, get_rms_envelope
from config import SR, HIT_WINDOW_LEN
router = APIRouter(prefix="/api", tags=["process"])
# Downsample factor for waveform preview (send 1 in N samples to frontend)
PREVIEW_DOWNSAMPLE = 64 # 48000 Hz → 750 Hz preview resolution
def downsample(arr: np.ndarray, factor: int) -> list[float]:
"""Return a downsampled version of the array as a Python list."""
return arr[::factor].tolist()
@router.post("/process")
async def process_files(session_id: str = Header(..., alias="X-Session-Id")):
"""
Extract hits from all uploaded training files.
Stores results in session.hits and session.processing_stats.
Returns per-file stats, class distribution, and waveform preview of first hit.
"""
session = session_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404, detail="Session not found or expired")
if not session.uploaded_files:
raise HTTPException(status_code=400, detail="No files uploaded. Call POST /api/upload first.")
all_waveforms: list[list[float]] = []
all_labels: list[int] = []
all_flange_groups: list[int] = []
all_area_groups: list[int] = []
per_file_stats: list[dict] = []
all_quality_logs: list[dict] = []
for finfo in session.uploaded_files:
windows, quality_log = extract_hits_from_file(
filepath=finfo["filepath"],
class_idx=finfo["class_idx"],
flange_id=finfo["flange_id"],
area_id=finfo["area_id"],
)
kept = len(windows)
detected = len(quality_log)
rejected = detected - kept
for win in windows:
all_waveforms.append(win.tolist())
all_labels.append(finfo["class_idx"])
all_flange_groups.append(finfo["flange_id"])
all_area_groups.append(finfo["area_id"])
for entry in quality_log:
entry["filename"] = finfo["filename"]
entry["flange_id"] = finfo["flange_id"]
entry["class_label"] = finfo["class_label"]
entry["area_id"] = finfo["area_id"]
all_quality_logs.extend(quality_log)
per_file_stats.append({
"filename": finfo["filename"],
"flange_id": finfo["flange_id"],
"class_label": finfo["class_label"],
"area_id": finfo["area_id"],
"detected": detected,
"kept": kept,
"rejected": rejected,
})
# Store in session (as lists — avoid numpy serialisation issues)
session.hits = {
"waveforms": all_waveforms,
"labels": all_labels,
"flange_groups": all_flange_groups,
"area_groups": all_area_groups,
"n_hits": len(all_waveforms),
"hit_window_len": HIT_WINDOW_LEN,
"sr": SR,
}
# Class distribution
labels_arr = np.array(all_labels)
from config import IDX_TO_CLASS
class_dist = {
str(IDX_TO_CLASS[idx]): int((labels_arr == idx).sum())
for idx in [0, 1, 2]
}
# Flange distribution
flanges_arr = np.array(all_flange_groups)
flange_dist = {
str(fl): int((flanges_arr == fl).sum())
for fl in [1, 2, 3, 4]
}
# Waveform preview: first kept hit (downsampled for network efficiency)
preview_waveform: list[float] = []
preview_rms: list[float] = []
if all_waveforms:
win0 = np.array(all_waveforms[0])
preview_waveform = downsample(win0, PREVIEW_DOWNSAMPLE)
# RMS envelope of full first file (for waveform page visualisation)
first_file_rms_preview: list[float] = []
first_file_waveform_preview: list[float] = []
if session.uploaded_files:
try:
y0 = load_audio(session.uploaded_files[0]["filepath"])
rms0, _ = get_rms_envelope(y0)
first_file_rms_preview = downsample(rms0, 4)
first_file_waveform_preview = downsample(y0, PREVIEW_DOWNSAMPLE)
except Exception:
pass
stats = {
"n_files": len(session.uploaded_files),
"n_hits_total": len(all_waveforms),
"n_hits_rejected": sum(s["rejected"] for s in per_file_stats),
"class_dist": class_dist,
"flange_dist": flange_dist,
"per_file": per_file_stats,
"quality_log": all_quality_logs[:200], # cap to avoid huge response
}
session.processing_stats = stats
session.touch()
return {
"status": "done",
"n_hits": len(all_waveforms),
"stats": stats,
"preview_hit_waveform": preview_waveform,
"first_file_waveform": first_file_waveform_preview,
"first_file_rms": first_file_rms_preview,
"downsample_factor": PREVIEW_DOWNSAMPLE,
"preview_sr_hz": SR // PREVIEW_DOWNSAMPLE,
}
@router.get("/process/hit/{hit_idx}")
async def get_hit_waveform(
hit_idx: int,
session_id: str = Header(..., alias="X-Session-Id"),
):
"""Return the waveform for a specific hit index (downsampled)."""
session = session_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404, detail="Session not found")
hits = session.hits
if not hits or hit_idx >= hits["n_hits"]:
raise HTTPException(status_code=404, detail=f"Hit {hit_idx} not found")
win = np.array(hits["waveforms"][hit_idx])
return {
"hit_idx": hit_idx,
"label": hits["labels"][hit_idx],
"flange_id": hits["flange_groups"][hit_idx],
"area_id": hits["area_groups"][hit_idx],
"waveform": downsample(win, PREVIEW_DOWNSAMPLE),
"waveform_full_len": len(hits["waveforms"][hit_idx]),
}