# snare_scout/src/embeddings_panns.py
# (repo history: files moved to src, commit 1a46553)
"""
PANNs Audio Embeddings - v6.6.1
Designed for:
- Full-clip embeddings (call once per clip)
- CPU-safe default (avoids MPS weirdness)
"""
from __future__ import annotations
import os
import numpy as np
try:
import librosa
except ImportError:
librosa = None
try:
import soxr
SOXR_AVAILABLE = True
except ImportError:
SOXR_AVAILABLE = False
_PANNS_MODEL = None
def get_panns_model(checkpoint_path: str = "Cnn14_mAP=0.431.pth", device: str = "cpu"):
    """
    Load the PANNs ``AudioTagging`` model once and cache it process-wide.

    The first call loads the checkpoint; every later call returns the cached
    instance. Fix over the original: if a later call requests a *different*
    checkpoint or device, we now print a loud warning instead of silently
    handing back the mismatched cached model.

    Parameters
    ----------
    checkpoint_path : str
        Path to the Cnn14 checkpoint file.
    device : str
        Torch device string; "cpu" is the safe default (see module docstring).

    Returns
    -------
    The cached ``panns_inference.AudioTagging`` wrapper.

    Raises
    ------
    ImportError
        If ``panns_inference`` is not installed.
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    global _PANNS_MODEL
    if _PANNS_MODEL is not None:
        # Detect callers asking for a different model than the one cached.
        cached_key = getattr(_PANNS_MODEL, "_snare_scout_key", None)
        if cached_key is not None and cached_key != (checkpoint_path, device):
            print(
                f"[panns] WARNING: cached model was loaded with {cached_key}; "
                f"ignoring requested ({checkpoint_path!r}, {device!r})"
            )
        return _PANNS_MODEL
    try:
        from panns_inference import AudioTagging
    except ImportError as e:
        raise ImportError("panns_inference not installed. Try: pip install panns-inference") from e
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Missing PANNs checkpoint: {checkpoint_path}\n"
            "Download it (mac):\n"
            " curl -L 'https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1' -o Cnn14_mAP=0.431.pth"
        )
    print("[panns] Loading PANNs model...")
    print(f"[panns] Device: {device}")
    print(f"[panns] Checkpoint: {checkpoint_path}")
    model = AudioTagging(checkpoint_path=checkpoint_path, device=device)
    # Put the underlying torch module in eval() mode; different versions of
    # panns_inference expose it under different attribute names, and the
    # wrapper itself may or may not have an eval() method.
    for attr in ("model", "audio_model"):
        torch_model = getattr(model, attr, None)
        if torch_model is not None and hasattr(torch_model, "eval"):
            torch_model.eval()
            break
    else:
        if hasattr(model, "eval"):
            model.eval()
    # Remember what we loaded so mismatched later calls can be detected.
    try:
        model._snare_scout_key = (checkpoint_path, device)
    except Exception:
        pass  # best-effort: some wrappers may forbid new attributes
    _PANNS_MODEL = model
    print("[panns] Loaded")
    return _PANNS_MODEL
def _resample(audio: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
if sr == target_sr:
return audio.astype(np.float32, copy=False)
audio = audio.astype(np.float32, copy=False)
if SOXR_AVAILABLE:
return soxr.resample(audio, sr, target_sr, quality="HQ").astype(np.float32, copy=False)
if librosa is None:
raise ImportError("Need librosa (or soxr) to resample for PANNs")
return librosa.resample(audio, orig_sr=sr, target_sr=target_sr, res_type="kaiser_best").astype(np.float32, copy=False)
def embed_audio_panns(audio: np.ndarray, sr: int, target_sr: int = 32000) -> np.ndarray:
    """
    Return an L2-normalized PANNs embedding (1-D float32) for a full clip.

    Call once per clip (full) to avoid any multi-call weirdness.

    Parameters
    ----------
    audio : np.ndarray
        Mono waveform (1-D), or multi-channel array which is averaged over
        axis 0. NOTE(review): this assumes channels-first layout — confirm
        against callers; soundfile-style (samples, channels) would need axis=1.
    sr : int
        Sample rate of ``audio``.
    target_sr : int
        Sample rate expected by the PANNs checkpoint (32 kHz for Cnn14).

    Returns
    -------
    np.ndarray
        L2-normalized embedding vector, shape (embedding_dim,).

    Raises
    ------
    ValueError
        If ``audio`` is empty. (The original code let empty input fall
        through to model.inference, failing with a cryptic deep error.)
    """
    if audio.ndim > 1:
        # Downmix to mono (assumes channels on axis 0 — see docstring note).
        audio = np.mean(audio, axis=0)
    audio = audio.astype(np.float32, copy=False)
    if audio.size == 0:
        raise ValueError("embed_audio_panns: got empty audio")
    # Peak-normalize so level differences don't dominate the embedding.
    peak = float(np.max(np.abs(audio)))
    if peak > 1e-6:
        audio = audio / peak
    # Remove DC offset (kept after peak-normalization to match original order).
    audio = audio - float(np.mean(audio))
    audio = _resample(audio, sr, target_sr)
    model = get_panns_model(device="cpu")  # CPU-safe default (module docstring)
    audio = np.ascontiguousarray(audio, dtype=np.float32)
    # PANNs expects a (batch, samples) array; embed the whole clip at once.
    clipwise_output, embedding = model.inference(audio[np.newaxis, :])
    emb = np.asarray(embedding).reshape(-1).astype(np.float32, copy=False)
    # L2-normalize so cosine similarity reduces to a dot product downstream.
    n = float(np.linalg.norm(emb))
    if n > 1e-9:
        emb = emb / n
    return emb