# snare_scout / preprocessing.py
# (upload residue preserved: contributor john221113, initial commit 5076f2a)
"""
Stable Preprocessing Pipeline - v6.6.1
SINGLE source of truth for:
- Canonical SR
- Onset window lengths
- Transient/tail slicing
Key principle: Same audio → Same representation (deterministic)
"""
import os
import io
import numpy as np
import librosa
import soundfile as sf
try:
import soxr
SOXR_AVAILABLE = True
except ImportError:
SOXR_AVAILABLE = False
print("[preprocessing] Warning: soxr not available, using librosa (less deterministic)")
# =========================
# CANONICAL SETTINGS
# =========================
CANONICAL_SR = 48000  # every input is resampled to this rate

# Onset-relative window lengths (ms)
ONSET_PRE_MS = 15
ONSET_POST_MS = 735  # total window = 750ms

# View presets (all in ms, relative to START OF WINDOW)
# NOTE: start of window is ONSET_PRE_MS before onset.
VIEW_PRESETS = {
    # Optimized for short one-shots / CNN-style event models
    "hits": {
        "TRANS_END_MS": 85,   # first 85ms of window (15ms pre + 70ms post)
        "TAIL_START_MS": 30,  # skip earliest transient region
        "TAIL_END_MS": 650,   # capture body/decay
    },
    # Sometimes helps transformer encoders on micro-clips (requires reindex to compare)
    "transformer": {
        "TRANS_END_MS": 140,  # longer transient context
        "TAIL_START_MS": 40,
        "TAIL_END_MS": 700,
    }
}

# Preset selection: env override, falling back to "hits" for unknown values.
_requested_preset = os.getenv("SCOUT_VIEW_PRESET", "hits").strip().lower()
DEFAULT_VIEW_PRESET = _requested_preset if _requested_preset in VIEW_PRESETS else "hits"

# Re-export the active preset's values so scout.py cannot drift from them.
_active_preset = VIEW_PRESETS[DEFAULT_VIEW_PRESET]
TRANS_END_MS = _active_preset["TRANS_END_MS"]
TAIL_START_MS = _active_preset["TAIL_START_MS"]
TAIL_END_MS = _active_preset["TAIL_END_MS"]
def canonicalize_audio(audio: np.ndarray, sr: int):
    """
    Deterministic audio canonicalization:
      1) Downmix to mono
      2) Resample to CANONICAL_SR
      3) Remove DC offset
      4) Peak normalize to ±1

    Fix vs. previous version: DC offset is removed BEFORE peak
    normalization. The old order (normalize, then subtract the mean)
    could leave samples slightly outside the documented ±1 range.
    The docstring also now matches the actual mono-then-resample order.

    Parameters
    ----------
    audio : np.ndarray
        Input samples. 2-D input is averaged over axis 1.
        NOTE(review): assumes (frames, channels) layout as produced by
        soundfile.read(always_2d=False) — confirm for other callers.
    sr : int
        Native sample rate of `audio`.

    Returns
    -------
    tuple[np.ndarray, int]
        float32 mono audio at CANONICAL_SR, and CANONICAL_SR itself.
    """
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    audio = audio.astype(np.float32, copy=False)
    if sr != CANONICAL_SR:
        if SOXR_AVAILABLE:
            # soxr "HQ" resampling is deterministic across runs
            audio = soxr.resample(audio, sr, CANONICAL_SR, quality="HQ")
        else:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=CANONICAL_SR, res_type="kaiser_best")
        sr = CANONICAL_SR
    # Remove DC offset first so the peak normalization below is exact.
    if audio.size:
        audio = audio - float(np.mean(audio))
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    if peak > 1e-6:  # skip normalization for (near-)silent input
        audio = audio / peak
    return audio.astype(np.float32, copy=False), CANONICAL_SR
def detect_primary_onset_stable(audio: np.ndarray, sr: int) -> int:
    """
    Deterministic onset detection.

    Picks the strongest onset-strength peak, converts it to a sample
    index, then nudges it to the nearest zero crossing within a small
    neighborhood. Falls back to 100ms in when no peak is found.

    Returns the onset sample index.
    """
    hop = 256
    strength = librosa.onset.onset_strength(
        y=audio,
        sr=sr,
        hop_length=hop,
        aggregate=np.median,  # median aggregation for stability
        center=False
    )
    candidates = librosa.util.peak_pick(
        strength,
        pre_max=3,
        post_max=3,
        pre_avg=3,
        post_avg=5,
        delta=0.05,
        wait=10
    )
    if len(candidates) == 0:
        # No detectable peak: assume the hit starts 100ms in.
        return int(0.1 * sr)
    best_frame = int(candidates[int(np.argmax(strength[candidates]))])
    onset = int(librosa.frames_to_samples(best_frame, hop_length=hop))
    # Refine to the nearest zero crossing within ±100 samples.
    half = 100
    lo = max(0, onset - half)
    hi = min(len(audio), onset + half)
    if hi > lo + 2:
        segment = audio[lo:hi]
        crossings = np.nonzero(np.diff(np.sign(segment)))[0]
        if crossings.size:
            target = half if onset >= half else onset - lo
            onset = lo + int(crossings[int(np.abs(crossings - target).argmin())])
    return int(onset)
def extract_canonical_window(audio: np.ndarray, sr: int, onset_sample: int) -> np.ndarray:
"""
Extract fixed-length window around onset.
Always returns exactly (ONSET_PRE_MS + ONSET_POST_MS) ms length at sr.
"""
pre_samples = int(ONSET_PRE_MS * sr / 1000.0)
post_samples = int(ONSET_POST_MS * sr / 1000.0)
expected = pre_samples + post_samples
start = max(0, onset_sample - pre_samples)
end = min(len(audio), onset_sample + post_samples)
w = audio[start:end].astype(np.float32, copy=False)
if w.size < expected:
w = np.pad(w, (0, expected - w.size), mode="constant")
elif w.size > expected:
w = w[:expected]
return w.astype(np.float32, copy=False)
def preprocess_audio_stable(audio_bytes: bytes):
    """
    MASTER preprocessing for QUERY uploads (file bytes).

    Pipeline: decode -> canonicalize -> onset detection -> fixed window.
    Returns a dict with the windowed audio, sample rate, and the onset
    position in both seconds and samples.
    """
    raw, native_sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=False)
    canonical, sr = canonicalize_audio(raw, native_sr)
    onset = detect_primary_onset_stable(canonical, sr)
    return {
        "audio": extract_canonical_window(canonical, sr, onset),
        "sr": sr,
        "onset_time": onset / sr,
        "onset_sample": onset
    }
def slice_views_stable(processed: dict, view_preset: str | None = None):
    """
    Build full/trans/tail views from a canonical window.

    view_preset:
      - None => uses DEFAULT_VIEW_PRESET (env SCOUT_VIEW_PRESET)
      - "hits" or "transformer"
    Unknown preset names silently fall back to "hits".
    """
    audio = processed["audio"]
    sr = processed["sr"]
    chosen = (view_preset or DEFAULT_VIEW_PRESET).strip().lower()
    if chosen not in VIEW_PRESETS:
        chosen = "hits"
    cfg = VIEW_PRESETS[chosen]

    def _clamped(ms):
        # ms -> sample index, clamped into [0, audio.size]
        return max(0, min(int(ms * sr / 1000.0), audio.size))

    trans_view = audio[:_clamped(cfg["TRANS_END_MS"])]
    tail_view = audio[_clamped(cfg["TAIL_START_MS"]):_clamped(cfg["TAIL_END_MS"])]
    return {"full": audio, "trans": trans_view, "tail": tail_view}
def verify_stability():
    """
    Self-test: run the same synthetic hit through the pipeline five
    times and confirm the output windows are bit-for-bit stable.
    Returns True on success, False if any run diverges.
    """
    print("[preprocessing] Running stability test...")
    rate = 48000
    timeline = np.linspace(0, 1.0, int(rate * 1.0), endpoint=False)
    tone = (np.sin(2 * np.pi * 200 * timeline) * np.exp(-timeline * 5)).astype(np.float32)
    buf = io.BytesIO()
    sf.write(buf, tone, rate, format="WAV")
    payload = buf.getvalue()
    windows = [preprocess_audio_stable(payload)["audio"] for _ in range(5)]
    reference = windows[0]
    for candidate in windows[1:]:
        diff = float(np.max(np.abs(reference - candidate)))
        if diff > 1e-6:
            print(f"[preprocessing] ⚠️ Instability detected: {diff}")
            return False
    print("[preprocessing] ✓ Stability test passed")
    return True
# Running this module directly executes the determinism self-check.
if __name__ == "__main__":
    verify_stability()