respitriage / models /opera_encoder.py
SujalSha's picture
Fix mel spectrogram f_max=8000 to match OPERA training preprocessing
c14ab18 verified
"""
models/opera_encoder.py — Fast batched OPERA-CT encoder.
Bypasses OPERA's sequential per-file loop. Instead:
- Preprocesses audio files in parallel using ThreadPoolExecutor (CPU)
- Batches the mel spectrograms and runs one GPU forward pass per batch
- Achieves ~60-80% GPU utilisation vs ~3% with OPERA's default loop
OPERA-CT output: 768-dim L2-normalised embedding per audio clip.
"""
import os
import sys
import numpy as np
import torch
import librosa
from concurrent.futures import ThreadPoolExecutor, as_completed
# Sample rate (Hz) that audio is resampled to when it is loaded in this module.
SAMPLE_RATE = 16000
# Width of one OPERA-CT embedding vector.
OPERA_CT_DIM = 768

# The OPERA checkout is expected in an 'OPERA' directory next to the directory
# that contains this file. It goes to the front of sys.path so that OPERA's
# own `src.*` packages are importable.
OPERA_REPO = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'OPERA',
)
if OPERA_REPO not in sys.path:
    sys.path.insert(0, OPERA_REPO)
def _to_wav_if_needed(file_path: str) -> tuple[str, bool]:
"""
Convert non-WAV audio to a temporary WAV file for OPERA compatibility.
Returns (path_to_use, should_delete).
OPERA's get_entire_signal_librosa appends .wav internally, so non-WAV
files must be converted first.
"""
if file_path.lower().endswith('.wav'):
return file_path, False
try:
import librosa
import soundfile as sf
import tempfile
y, sr = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
sf.write(tmp.name, y, SAMPLE_RATE)
tmp.close()
return tmp.name, True
except Exception:
return file_path, False
def _get_mel_spectrogram(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
    """
    Compute a min-max scaled log-mel spectrogram, returned as (time, mel_bins).

    Inline reimplementation of OPERA's pre_process_audio_mel_t with f_max=8000.
    Must match get_entire_signal_librosa, which calls
    pre_process_audio_mel_t(yt, f_max=8000).
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=64,
        fmin=50,
        fmax=8000,
    )
    # dB scale relative to the loudest bin of this clip.
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    lowest, highest = mel_db.min(), mel_db.max()
    # A constant spectrogram is returned unscaled (avoids dividing by zero).
    if highest == lowest:
        return mel_db.T
    # Otherwise rescale to [0, 1] and put time on the first axis.
    return ((mel_db - lowest) / (highest - lowest)).T
def _preprocess_one(file_path: str, input_sec: int = 8) -> np.ndarray | None:
    """
    Load and preprocess one audio file to a mel spectrogram.

    Inlines OPERA's get_entire_signal_librosa + pre_process_audio_mel_t
    to avoid importing src.util (which pulls in matplotlib/seaborn).

    Steps: resolve the file to a WAV path, load it as mono at SAMPLE_RATE,
    trim leading/trailing silence, repeat-pad clips shorter than
    ``input_sec`` to exactly that length, then compute the spectrogram.
    Clips longer than ``input_sec`` are passed through at full length.

    Parameters
    ----------
    file_path : path to the audio file.
    input_sec : minimum clip length in seconds (int or float).

    Returns
    -------
    The spectrogram produced by _get_mel_spectrogram, or None if the file
    could not be processed (the error is reported on stderr).
    """
    sample_rate = SAMPLE_RATE
    file_path = os.path.abspath(file_path)
    wav_path, should_delete = _to_wav_if_needed(file_path)
    try:
        data, _ = librosa.load(wav_path, sr=sample_rate, mono=True)

        # Trim silence using frames of 1/10 s with 50% overlap.
        frame_len = sample_rate // 10
        hop = frame_len // 2
        yt, _ = librosa.effects.trim(data, frame_length=frame_len, hop_length=hop)

        # An empty signal cannot be repeat-padded (it would divide by zero).
        if len(yt) == 0:
            sys.stderr.write(
                f"[opera] preprocess ERROR: empty audio signal: {file_path}\n")
            sys.stderr.flush()
            return None

        # Work in whole samples so a float input_sec still yields a valid
        # slice bound.
        target_len = int(input_sec * sample_rate)
        if len(yt) < target_len:
            # Repeat-pad to target length (ceiling division for the count).
            n_repeat = -(-target_len // len(yt))
            yt = np.tile(yt, n_repeat)[:target_len]

        return _get_mel_spectrogram(yt, sample_rate)
    except Exception as _e:
        # Best-effort boundary: one bad file must not abort a whole batch.
        import traceback as _tb
        sys.stderr.write(f"[opera] preprocess ERROR: {_e}\n{_tb.format_exc()}\n")
        sys.stderr.flush()
        return None
    finally:
        # Remove the temporary WAV created by _to_wav_if_needed, if any.
        if should_delete:
            try:
                os.unlink(wav_path)
            except OSError:
                pass
class OPERAEncoder:
    """
    Fast batched OPERA-CT encoder.

    Preprocessing runs in parallel threads on the CPU; inference runs in
    batches on the selected torch device (CUDA when available, else CPU).

    Parameters
    ----------
    pretrain   : 'operaCT' (HT-SAT, 768-dim) - only CT supported here
    input_sec  : audio clip length in seconds (default 8)
    batch_size : GPU batch size (default 16 - safe for GTX 1650 4GB)
    n_workers  : CPU threads for parallel audio preprocessing (default 4)
    """

    def __init__(self,
                 pretrain: str = 'operaCT',
                 input_sec: int = 8,
                 batch_size: int = 16,
                 n_workers: int = 4):
        self.pretrain = pretrain
        self.input_sec = input_sec
        self.batch_size = batch_size
        self.n_workers = n_workers
        # Embedding width is fixed to the OPERA-CT size regardless of `pretrain`.
        self.dim = OPERA_CT_DIM
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._model = self._load_model()
        print(f"[OPERAEncoder] {pretrain} on {self.device} | "
              f"batch={batch_size} | workers={n_workers} | dim={self.dim}")

    def _load_model(self):
        """
        Build the pretrained OPERA model, load its checkpoint and freeze it.

        The working directory is switched to OPERA_REPO for the duration of
        the load and always restored afterwards.
        NOTE(review): presumably OPERA resolves its checkpoint paths relative
        to its repository root - confirm. os.chdir is process-wide, so do not
        construct encoders concurrently from several threads.

        Returns the model on ``self.device``, in eval mode, with gradients
        disabled for all parameters.
        """
        orig_dir = os.getcwd()
        os.chdir(OPERA_REPO)
        try:
            from src.benchmark.model_util import get_encoder_path, initialize_pretrained_model

            ckpt_path = get_encoder_path(self.pretrain)
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # obtained from a trusted source.
            ckpt = torch.load(ckpt_path, map_location=self.device)
            model = initialize_pretrained_model(self.pretrain)

            # strict=False tolerates key mismatches, so report weights that
            # were NOT found instead of silently keeping their initial values.
            load_result = model.load_state_dict(ckpt['state_dict'], strict=False)
            missing = getattr(load_result, 'missing_keys', None)
            if missing:
                sys.stderr.write(
                    f"[OPERAEncoder] warning: {len(missing)} model parameter(s) "
                    f"missing from checkpoint {ckpt_path}\n")

            model = model.to(self.device)
            model.eval()
            for p in model.parameters():
                p.requires_grad = False
            return model
        finally:
            os.chdir(orig_dir)

    def _infer_batch(self, specs: list) -> np.ndarray:
        """
        Run one forward pass over a list of (time, mel_bins) spectrograms.

        Spectrograms shorter than the longest one in the batch are zero-padded
        at the end of the time axis, then all are stacked into one float32
        tensor of shape (N, time, mel_bins) and passed to
        ``self._model.extract_feature(x, self.dim)``. Per the original notes,
        model.forward does unsqueeze(1) internally.
        NOTE(review): zero-padding presumably influences the embedding of the
        shorter clips in a mixed-length batch - confirm whether that matters.

        Returns the features as a NumPy array, expected to be (N, self.dim);
        an empty (0, self.dim) array is returned for an empty list.
        """
        if not specs:
            return np.zeros((0, self.dim), dtype=np.float32)

        # target_time is the batch maximum, so items only ever need padding.
        target_time = max(s.shape[0] for s in specs)
        padded = [
            np.pad(s, ((0, target_time - s.shape[0]), (0, 0)))
            if s.shape[0] < target_time else s
            for s in specs
        ]

        x = torch.tensor(np.stack(padded), dtype=torch.float32)
        x = x.to(self.device)  # (N, time, mel_bins)
        with torch.no_grad():
            features = self._model.extract_feature(x, self.dim)
        return features.cpu().numpy()

    def encode(self, audio_path: str) -> np.ndarray:
        """
        Encode a single file. Returns a (768,) L2-normalised embedding, or an
        all-zero vector if the file could not be processed.
        """
        return self.encode_batch([audio_path])[0]

    def encode_batch(self, audio_paths: list) -> np.ndarray:
        """
        Encode a list of audio files -> (N, 768) L2-normalised embeddings.

        Row i corresponds to audio_paths[i]. Files that fail preprocessing or
        inference keep an all-zero row (handled upstream); inference failures
        are reported on stderr.
        """
        audio_paths = list(audio_paths)
        n_files = len(audio_paths)
        all_embeddings = np.zeros((n_files, self.dim), dtype=np.float32)
        if n_files == 0:
            return all_embeddings

        # -- Parallel CPU preprocessing --------------------------------------
        results = [None] * n_files
        # ThreadPoolExecutor rejects max_workers <= 0, so clamp to one thread.
        with ThreadPoolExecutor(max_workers=max(1, self.n_workers)) as pool:
            futures = {
                pool.submit(_preprocess_one, path, self.input_sec): i
                for i, path in enumerate(audio_paths)
            }
            for future in as_completed(futures):
                results[futures[future]] = future.result()

        # Indices whose preprocessing succeeded, in input order.
        valid_idx = [i for i, spec in enumerate(results) if spec is not None]
        if not valid_idx:
            return all_embeddings

        # -- Batched GPU inference -------------------------------------------
        step = max(1, self.batch_size)  # range() needs a positive step
        for batch_start in range(0, len(valid_idx), step):
            batch_idx = valid_idx[batch_start: batch_start + step]
            batch_specs = [results[i] for i in batch_idx]
            try:
                all_embeddings[batch_idx] = self._infer_batch(batch_specs)
            except Exception as batch_exc:
                sys.stderr.write(
                    f"[OPERAEncoder] batch inference failed ({batch_exc}); "
                    f"retrying {len(batch_idx)} file(s) one by one\n")
                # Fall back to one-by-one so a single bad item does not zero
                # out the rest of its batch.
                for global_i, spec in zip(batch_idx, batch_specs):
                    try:
                        all_embeddings[global_i] = self._infer_batch([spec])[0]
                    except Exception as item_exc:
                        # Row stays as a zero vector.
                        sys.stderr.write(
                            f"[OPERAEncoder] inference failed for "
                            f"{audio_paths[global_i]!r}: {item_exc}\n")

        # L2 normalise (zero rows are divided by 1 and therefore stay zero).
        norms = np.linalg.norm(all_embeddings, axis=1, keepdims=True)
        norms = np.where(norms > 0, norms, 1.0)
        return all_embeddings / norms