| """ |
| Vocal stem separation using Demucs (HTDemucs). |
| |
| Extracts clean vocals from a mixed audio track for downstream transcription. |
| Uses htdemucs_ft (fine-tuned) for best quality (~9.2 dB SDR on MUSDB18-HQ). |
| """ |
|
|
| import logging |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
|
|
# Module-scoped logger, named after this module's import path (``__name__``).
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class VocalSeparator:
    """
    Separate vocals from mixed audio using Demucs HTDemucs model.

    The separated vocals are significantly cleaner for ASR than the original mix,
    reducing transcription WER by ~3-5% (per arxiv:2506.15514).

    The model is loaded lazily on the first call that needs it. If
    ``demucs.api`` is importable its ``Separator`` class is used; otherwise the
    lower-level ``demucs.pretrained.get_model`` + ``demucs.apply.apply_model``
    functions are used. Both paths return stems as CPU tensors.
    """

    def __init__(
        self,
        model_name: str = "htdemucs_ft",
        device: Optional[str] = None,
        segment_seconds: float = 7.8,
        overlap: float = 0.25,
        shifts: int = 1,
    ):
        """
        Args:
            model_name: Demucs model to use. Options:
                - "htdemucs_ft": Best quality, per-source fine-tuned (~9.2 dB SDR)
                - "htdemucs": Base model, slightly faster download (~8.7 dB SDR)
                - "htdemucs_6s": 6-stem (adds guitar, piano)
            device: "cuda", "cpu", or "mps". Auto-detected if None
                (CUDA first, then MPS, then CPU).
            segment_seconds: Processing chunk size. Lower = less VRAM.
                - 7.8: Default (matches training), ~4-6 GB VRAM
                - 4.0: For 8 GB GPUs
                - 2.0: For CPU processing
                Only forwarded to the ``demucs.api.Separator`` backend; the
                low-level fallback does not pass it to ``apply_model``.
            overlap: Overlap ratio between chunks (0.25 = 25%, matches paper).
            shifts: Test-time shift augmentation. 1=disabled, 5-10=better
                quality but N× slower. Used by both backends.
        """
        self.model_name = model_name
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.shifts = shifts

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # Backends, populated lazily by _load_model(); at most one is ever set.
        self._model = None
        self._separator = None

    def _load_model(self) -> None:
        """Lazy-load the Demucs model on first use; a no-op once loaded."""
        # Either backend counts as "loaded". Checking only ``_model`` (which the
        # Separator path never sets) rebuilt the Separator on every call.
        if self._model is not None or self._separator is not None:
            return

        try:
            from demucs.api import Separator
        except ImportError:
            # ``demucs.api`` is not available in this install: fall back to the
            # low-level pretrained-model API.
            from demucs.pretrained import get_model

            self._model = get_model(self.model_name)
            self._model.eval()
            self._model.to(self.device)
            logger.info(
                "Loaded Demucs via low-level API: %s on %s",
                self.model_name,
                self.device,
            )
        else:
            self._separator = Separator(
                model=self.model_name,
                device=self.device,
                segment=self.segment_seconds,
                overlap=self.overlap,
                shifts=self.shifts,
            )
            logger.info(
                "Loaded Demucs via Separator API: %s on %s",
                self.model_name,
                self.device,
            )

    @property
    def sample_rate(self) -> int:
        """Demucs native sample rate (always 44100)."""
        return 44100

    @staticmethod
    def _to_stereo(wav: torch.Tensor) -> torch.Tensor:
        """Coerce a [channels, samples] waveform to two channels.

        A single channel is duplicated; more than two channels are truncated
        to the first two (not downmixed). Two-channel input is returned as-is.
        """
        if wav.shape[0] == 1:
            return wav.repeat(2, 1)
        if wav.shape[0] > 2:
            return wav[:2]
        return wav

    def separate(self, audio_path: str) -> dict[str, torch.Tensor]:
        """
        Separate audio into stems.

        Args:
            audio_path: Path to audio file (any format supported by torchaudio)

        Returns:
            Dict mapping stem name → CPU tensor [channels, samples] at 44100 Hz.
            Keys: "drums", "bass", "other", "vocals" (plus "guitar"/"piano"
            for "htdemucs_6s").
        """
        self._load_model()

        wav, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        # Separation is always run on a 2-channel mix.
        wav = self._to_stereo(wav)

        if self._separator is not None:
            with torch.no_grad():
                _, stems = self._separator.separate_tensor(wav.to(self.device))
            # Stems come back on ``self.device``; move them to CPU so callers
            # (e.g. ``.numpy()`` in extract_vocals) also work on CUDA/MPS, and
            # so both backends return tensors on the same device.
            return {name: stem.cpu() for name, stem in stems.items()}
        return self._separate_low_level(wav)

    def _separate_low_level(self, wav: torch.Tensor) -> dict[str, torch.Tensor]:
        """Separate a 2-channel CPU mix with ``demucs.apply.apply_model``.

        Returns CPU stems keyed by ``self._model.sources``.
        """
        from demucs.apply import apply_model

        # Standardize the mix the way the Demucs CLI and ``Separator`` API do
        # before inference, then undo the scaling on the separated sources.
        ref = wav.mean(dim=0)
        mean = float(ref.mean())
        std = float(ref.std()) + 1e-8  # epsilon avoids dividing by zero on silence
        wav_batch = ((wav - mean) / std).unsqueeze(0).to(self.device)

        with torch.no_grad():
            sources = apply_model(
                self._model,
                wav_batch,
                device=self.device,
                shifts=self.shifts,
                split=True,
                overlap=self.overlap,
                progress=False,
            )
        sources = sources * std + mean

        # sources is [batch=1, n_sources, channels, samples].
        return {
            name: sources[0, idx].cpu()
            for idx, name in enumerate(self._model.sources)
        }

    def extract_vocals(
        self,
        audio_path: str,
        target_sr: int = 16000,
        mono: bool = True,
    ) -> tuple[np.ndarray, int]:
        """
        Extract vocals and prepare for ASR.

        Args:
            audio_path: Path to audio file
            target_sr: Target sample rate for ASR (16000 for Whisper)
            mono: Convert to mono by averaging channels (required by most ASR models)

        Returns:
            (vocals_array, sample_rate) — numpy float32 array ready for ASR;
            shape [samples] if mono else [channels, samples].
        """
        vocals = self.separate(audio_path)["vocals"]

        if mono:
            vocals = vocals.mean(dim=0)

        if target_sr != self.sample_rate:
            # Resample with an explicit channel axis, then drop it again for mono.
            added_channel_axis = vocals.dim() == 1
            if added_channel_axis:
                vocals = vocals.unsqueeze(0)
            vocals = torchaudio.functional.resample(vocals, self.sample_rate, target_sr)
            if added_channel_axis:
                vocals = vocals.squeeze(0)

        return vocals.numpy().astype(np.float32), target_sr

    def extract_vocals_full_rate(self, audio_path: str) -> tuple[np.ndarray, int]:
        """
        Extract mono vocals at the native 44100 Hz for onset/offset analysis.

        Returns:
            (vocals_mono_array, 44100) — numpy float32 at native rate
        """
        # No resampling happens because target_sr equals the native rate.
        return self.extract_vocals(audio_path, target_sr=self.sample_rate, mono=True)
|
|