"""
Vocal stem separation using Demucs (HTDemucs).

Extracts clean vocals from a mixed audio track for downstream transcription.
Uses htdemucs_ft (fine-tuned) for best quality (~9.2 dB SDR on MUSDB18-HQ).
"""

import logging
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torchaudio

logger = logging.getLogger(__name__)


class VocalSeparator:
    """
    Separate vocals from mixed audio using Demucs HTDemucs model.

    The separated vocals are significantly cleaner for ASR than the original
    mix, reducing transcription WER by ~3-5% (per arxiv:2506.15514).

    The Demucs model is loaded lazily on the first call to ``separate`` (or
    any method built on it) and reused for every subsequent call.
    """

    def __init__(
        self,
        model_name: str = "htdemucs_ft",
        device: Optional[str] = None,
        segment_seconds: float = 7.8,
        overlap: float = 0.25,
        shifts: int = 1,
    ):
        """
        Args:
            model_name: Demucs model to use. Options:
                - "htdemucs_ft": Best quality, per-source fine-tuned (~9.2 dB SDR)
                - "htdemucs": Base model, slightly faster download (~8.7 dB SDR)
                - "htdemucs_6s": 6-stem (adds guitar, piano)
            device: "cuda", "cpu", or "mps". Auto-detected if None
                (CUDA first, then MPS, then CPU).
            segment_seconds: Processing chunk size. Lower = less VRAM.
                - 7.8: Default (matches training), ~4-6 GB VRAM
                - 4.0: For 8 GB GPUs
                - 2.0: For CPU processing
            overlap: Overlap ratio between chunks (0.25 = 25%, matches paper).
            shifts: Test-time shift augmentation. 1=disabled, 5-10=better
                quality but N× slower.

        All of ``segment_seconds``, ``overlap`` and ``shifts`` are forwarded to
        whichever Demucs backend (high-level or low-level API) gets loaded.
        """
        self.model_name = model_name
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.shifts = shifts

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # Exactly one of these is populated by _load_model(): ``_separator``
        # when demucs.api is importable, otherwise ``_model`` (low-level API).
        self._model = None
        self._separator = None

    def _load_model(self) -> None:
        """Lazy-load the Demucs model on first use; later calls are no-ops.

        Prefers the high-level ``demucs.api.Separator`` (demucs >= 4.1) and
        falls back to ``demucs.pretrained.get_model`` when that module is not
        available.
        """
        # Either backend counts as "loaded". Checking only ``_model`` would
        # rebuild the Separator (reloading its weights) on every call.
        if self._model is not None or self._separator is not None:
            return

        try:
            from demucs.api import Separator
        except ImportError:
            Separator = None

        if Separator is not None:
            # High-level Separator API (demucs >= 4.1)
            self._separator = Separator(
                model=self.model_name,
                device=self.device,
                shifts=self.shifts,
                overlap=self.overlap,
                segment=self.segment_seconds,
            )
            logger.info(
                "Loaded Demucs via Separator API: %s on %s",
                self.model_name,
                self.device,
            )
            return

        # Fallback to low-level API
        from demucs.pretrained import get_model

        model = get_model(self.model_name)
        model.eval()
        model.to(self.device)
        self._model = model
        logger.info(
            "Loaded Demucs via low-level API: %s on %s",
            self.model_name,
            self.device,
        )

    @property
    def sample_rate(self) -> int:
        """Demucs native sample rate (always 44100)."""
        return 44100

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load ``audio_path`` as a 2-channel tensor at ``self.sample_rate``.

        Mono input is duplicated to both channels; input with more than two
        channels keeps only the first two.
        """
        wav, sr = torchaudio.load(audio_path)

        # Resample to model's native 44100 Hz
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # Ensure stereo (Demucs expects 2-channel)
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            wav = wav[:2]  # Take first 2 channels
        return wav

    def separate(self, audio_path: str) -> dict[str, torch.Tensor]:
        """
        Separate audio into stems.

        Args:
            audio_path: Path to audio file (any format supported by torchaudio)

        Returns:
            Dict mapping stem name → CPU tensor [channels, samples] at 44100 Hz.
            Keys: "drums", "bass", "other", "vocals" (plus "guitar" and
            "piano" for htdemucs_6s).
        """
        self._load_model()
        wav = self._load_audio(audio_path).to(self.device)

        with torch.no_grad():
            if self._separator is not None:
                # High-level API
                _, stems = self._separator.separate_tensor(wav)
            else:
                # Low-level API
                from demucs.apply import apply_model

                sources = apply_model(
                    self._model,
                    wav.unsqueeze(0),  # [1, 2, N]
                    device=self.device,
                    shifts=self.shifts,
                    split=True,
                    overlap=self.overlap,
                    segment=self.segment_seconds,
                    progress=False,
                )
                # sources: [1, num_sources, 2, N]
                stems = dict(zip(self._model.sources, sources[0]))

        # Always return CPU tensors so downstream .numpy() works no matter
        # which backend or device produced the stems.
        return {name: stem.cpu() for name, stem in stems.items()}

    def extract_vocals(
        self,
        audio_path: str,
        target_sr: int = 16000,
        mono: bool = True,
    ) -> tuple[np.ndarray, int]:
        """
        Extract vocals and prepare for ASR.

        Args:
            audio_path: Path to audio file
            target_sr: Target sample rate for ASR (16000 for Whisper)
            mono: Convert to mono by averaging channels (required by most ASR
                models)

        Returns:
            (vocals_array, sample_rate) — numpy float32 array ready for ASR,
            shaped [N] when ``mono`` else [2, N].
        """
        stems = self.separate(audio_path)
        vocals = stems["vocals"]  # [2, N] at 44100 Hz

        if mono:
            vocals = vocals.mean(dim=0)  # [N]

        if self.sample_rate != target_sr:
            # Give a 1-D signal an explicit channel axis for resampling,
            # then drop it again afterwards.
            add_channel_axis = vocals.dim() == 1
            if add_channel_axis:
                vocals = vocals.unsqueeze(0)
            vocals = torchaudio.functional.resample(vocals, self.sample_rate, target_sr)
            if add_channel_axis:
                vocals = vocals.squeeze(0)

        return vocals.numpy().astype(np.float32), target_sr

    def extract_vocals_full_rate(self, audio_path: str) -> tuple[np.ndarray, int]:
        """
        Extract vocals at full 44100 Hz for onset/offset analysis.

        Returns:
            (vocals_mono_array, 44100) — numpy float32 [N] at native rate
        """
        return self.extract_vocals(audio_path, target_sr=self.sample_rate, mono=True)