Spaces:
Running
Running
"""
models/opera_encoder.py -- Fast batched OPERA-CT encoder.

Bypasses OPERA's sequential per-file loop. Instead:
  - Preprocesses audio files in parallel using ThreadPoolExecutor (CPU)
  - Batches the mel spectrograms and runs one GPU forward pass per batch
  - Achieves ~60-80% GPU utilisation vs ~3% with OPERA's default loop
    (author's measurement; hardware-dependent)

OPERA-CT output: 768-dim L2-normalised embedding per audio clip.
"""
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import librosa | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
# Make the vendored OPERA checkout importable. The path is resolved through
# abspath so it stays valid when __file__ is relative and after the cwd
# changes (OPERAEncoder._load_model temporarily chdirs into OPERA_REPO, which
# would otherwise break a relative sys.path entry).
OPERA_REPO = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'OPERA')
if OPERA_REPO not in sys.path:
    sys.path.insert(0, OPERA_REPO)

OPERA_CT_DIM = 768   # width of the OPERA-CT embedding produced per clip
SAMPLE_RATE = 16000  # Hz; all audio is resampled to this rate before mel extraction
def _to_wav_if_needed(file_path: str) -> tuple[str, bool]:
    """
    Convert non-WAV audio to a temporary WAV file for OPERA compatibility.

    OPERA's get_entire_signal_librosa appends '.wav' to the name it is given
    internally, so non-WAV files must be converted before being passed to it.

    Returns (path_to_use, should_delete). ``should_delete`` is True only when
    a temporary file was created; the caller then owns it and must unlink it.
    On any failure an error is logged to stderr and the original path is
    returned with ``should_delete=False`` (best effort, never raises).
    """
    if file_path.lower().endswith('.wav'):
        return file_path, False
    tmp_path = None
    try:
        import tempfile
        import librosa
        import soundfile as sf
        y, _ = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)
        # Create the file, then close our own handle before soundfile opens
        # the path itself, so no descriptor is left open on any code path.
        fd, tmp_path = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        sf.write(tmp_path, y, SAMPLE_RATE)
        return tmp_path, True
    except Exception as exc:
        sys.stderr.write(f"[opera] wav conversion failed for {file_path}: {exc}\n")
        # Don't leave a half-written temporary file behind.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        return file_path, False
def _get_mel_spectrogram(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
    """Compute OPERA's min-max-normalised log-mel spectrogram for one clip.

    Inline copy of OPERA's ``pre_process_audio_mel_t`` exactly as
    ``get_entire_signal_librosa`` calls it (``f_max=8000``). Keep the two in
    sync, or embeddings will drift from the reference pipeline.

    For a 1-D signal the result has shape (time, 64): frames on axis 0, mel
    bins on axis 1, scaled to [0, 1] unless the spectrogram is constant.
    """
    power = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_mels=64,
        fmin=50,
        fmax=8000,
        n_fft=1024,
        hop_length=512,
    )
    db = librosa.power_to_db(power, ref=np.max)
    lo, hi = db.min(), db.max()
    # A flat spectrogram has zero range and cannot be min-max scaled; pass it through.
    normalised = db if hi == lo else (db - lo) / (hi - lo)
    return normalised.T
def _preprocess_one(file_path: str, input_sec: int = 8) -> np.ndarray | None:
    """
    Load and preprocess one audio file into a mel spectrogram.

    Inlines OPERA's get_entire_signal_librosa + pre_process_audio_mel_t to
    avoid importing src.util (which pulls in matplotlib/seaborn):

    1. decode and resample to SAMPLE_RATE mono,
    2. trim leading/trailing silence (1/10 s frames, half-frame hop),
    3. repeat-pad clips shorter than ``input_sec`` to exactly ``input_sec``
       seconds; longer clips are kept whole,
    4. convert with _get_mel_spectrogram -> (time, mel_bins).

    librosa.load decodes non-WAV formats itself, so the file is read directly
    rather than round-tripped through a temporary 16-bit WAV. The '.wav'
    requirement only applies to OPERA's own loader, which is not used here.

    Returns None (after logging to stderr) when the file cannot be processed.
    """
    sample_rate = SAMPLE_RATE
    file_path = os.path.abspath(file_path)
    try:
        data, _ = librosa.load(file_path, sr=sample_rate, mono=True)
        # Trim silence with OPERA's framing: 100 ms frames, 50 ms hop.
        frame_len = sample_rate // 10
        yt, _ = librosa.effects.trim(data, frame_length=frame_len, hop_length=frame_len // 2)
        if yt.size == 0:
            # Nothing to tile; fail explicitly instead of dividing by zero below.
            sys.stderr.write(f"[opera] preprocess ERROR: no audio samples in {file_path}\n")
            sys.stderr.flush()
            return None
        # Integer sample count so a fractional input_sec still slices correctly.
        target_len = int(input_sec * sample_rate)
        if len(yt) < target_len:  # i.e. duration < input_sec
            n_repeat = -(-target_len // len(yt))  # ceil(target_len / len(yt))
            yt = np.tile(yt, n_repeat)[:target_len]
        return _get_mel_spectrogram(yt, sample_rate)
    except Exception as _e:
        import traceback as _tb
        sys.stderr.write(f"[opera] preprocess ERROR: {_e}\n{_tb.format_exc()}\n"); sys.stderr.flush()
        return None
class OPERAEncoder:
    """
    Fast batched OPERA-CT encoder.

    Preprocessing (decode, trim, pad, log-mel) runs in a thread pool on the
    CPU. Inference runs in batches on CUDA when available, otherwise on CPU.

    Only spectrograms with identical shapes are batched together. A clip is
    therefore never zero-padded to match a longer neighbour, and its embedding
    does not depend on which other files share its batch.

    Parameters
    ----------
    pretrain : 'operaCT' (HT-SAT, 768-dim) -- only CT supported here
    input_sec : clips shorter than this many seconds (after silence trimming)
        are repeat-padded up to it; longer clips are kept whole (default 8)
    batch_size : max clips per forward pass (default 16 -- safe for GTX 1650 4GB)
    n_workers : CPU threads for parallel audio preprocessing (default 4)
    """

    def __init__(self,
                 pretrain: str = 'operaCT',
                 input_sec: int = 8,
                 batch_size: int = 16,
                 n_workers: int = 4):
        self.pretrain = pretrain
        self.input_sec = input_sec
        self.batch_size = batch_size
        self.n_workers = n_workers
        self.dim = OPERA_CT_DIM
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._model = self._load_model()
        print(f"[OPERAEncoder] {pretrain} on {self.device} | "
              f"batch={batch_size} | workers={n_workers} | dim={self.dim}")

    # -- model setup / inference -------------------------------------------

    def _load_model(self):
        """Build the OPERA model, load its pretrained weights, and freeze it.

        The cwd is switched to OPERA_REPO while importing and loading
        (presumably because OPERA's model_util resolves checkpoint paths
        relative to its repo root -- confirm). It is always restored.
        """
        orig_dir = os.getcwd()
        os.chdir(OPERA_REPO)
        try:
            from src.benchmark.model_util import get_encoder_path, initialize_pretrained_model
            ckpt_path = get_encoder_path(self.pretrain)
            ckpt = torch.load(ckpt_path, map_location=self.device)
            model = initialize_pretrained_model(self.pretrain)
            # strict=False tolerates key mismatches. Report missing weights
            # rather than silently running with freshly initialised tensors.
            incompatible = model.load_state_dict(ckpt['state_dict'], strict=False)
            if incompatible.missing_keys:
                sys.stderr.write(
                    f"[OPERAEncoder] WARNING: {len(incompatible.missing_keys)} model "
                    f"tensors missing from checkpoint {ckpt_path}; they keep their "
                    f"initial values\n")
            model = model.to(self.device)
            model.eval()
            for p in model.parameters():
                p.requires_grad = False
            return model
        finally:
            os.chdir(orig_dir)

    def _infer_batch(self, specs: list) -> np.ndarray:
        """
        Run one forward pass over a list of mel spectrograms.

        Each spec is (time, mel_bins), as produced by _preprocess_one. The
        specs are stacked into an (N, time, mel_bins) float32 tensor. Specs
        shorter than the longest are zero-padded at the end of the time axis;
        encode_batch avoids this by only grouping equal shapes. The model is
        expected to add the channel dimension itself -- confirm against
        OPERA's forward() before changing this layout.

        Returns: np.ndarray of shape (N, self.dim), not yet L2-normalised.
        """
        target_time = max(s.shape[0] for s in specs)
        padded = [
            s if s.shape[0] == target_time
            else np.pad(s, ((0, target_time - s.shape[0]), (0, 0)))
            for s in specs
        ]
        x = torch.from_numpy(np.stack(padded).astype(np.float32, copy=False))
        x = x.to(self.device)  # (N, time, mel_bins)
        with torch.no_grad():
            features = self._model.extract_feature(x, self.dim)  # (N, dim)
        return features.cpu().numpy()

    # -- public API -----------------------------------------------------------

    def encode(self, audio_path: str) -> np.ndarray:
        """Encode a single file into a (dim,) L2-normalised embedding (zeros on failure)."""
        return self.encode_batch([audio_path])[0]

    def encode_batch(self, audio_paths: list) -> np.ndarray:
        """
        Encode a list of audio files -> (N, dim) L2-normalised embeddings.

        Row i corresponds to audio_paths[i]. Files that fail preprocessing or
        inference get an all-zero row (handled upstream); failures are logged
        to stderr.
        """
        embeddings = np.zeros((len(audio_paths), self.dim), dtype=np.float32)

        specs = self._preprocess_all(audio_paths)
        for batch_idx in self._plan_batches(specs):
            self._encode_group(batch_idx, specs, audio_paths, embeddings)

        # L2-normalise; all-zero (failed) rows are left as zeros.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms > 0, norms, 1.0)
        return embeddings / norms

    # -- encode_batch helpers ---------------------------------------------------

    def _preprocess_all(self, audio_paths: list) -> dict:
        """Preprocess files in parallel; map input index -> spectrogram (successes only)."""
        input_sec = self.input_sec
        with ThreadPoolExecutor(max_workers=self.n_workers) as pool:
            results = list(pool.map(lambda p: _preprocess_one(p, input_sec), audio_paths))
        return {i: spec for i, spec in enumerate(results) if spec is not None}

    def _plan_batches(self, specs: dict) -> list:
        """Group indices by spectrogram shape, then split each group into batches."""
        groups = {}
        for i in sorted(specs):
            groups.setdefault(specs[i].shape, []).append(i)
        step = max(1, self.batch_size)  # guard against batch_size <= 0
        return [
            idx[start:start + step]
            for idx in groups.values()
            for start in range(0, len(idx), step)
        ]

    def _encode_group(self, batch_idx: list, specs: dict, audio_paths: list,
                      out: np.ndarray) -> None:
        """Write embeddings for one batch into ``out``.

        If the whole batch fails (e.g. it does not fit in device memory),
        retry clip by clip so one failure doesn't zero out its neighbours.
        """
        try:
            out[batch_idx] = self._infer_batch([specs[i] for i in batch_idx])
        except Exception as exc:
            sys.stderr.write(f"[opera] batch inference failed ({exc}); "
                             f"retrying {len(batch_idx)} clips one by one\n")
            for i in batch_idx:
                try:
                    out[i] = self._infer_batch([specs[i]])[0]
                except Exception as one_exc:
                    # Row stays a zero vector.
                    sys.stderr.write(f"[opera] inference ERROR for {audio_paths[i]}: {one_exc}\n")