"""
Vocal stem separation using Demucs (HTDemucs).
Extracts clean vocals from a mixed audio track for downstream transcription.
Uses htdemucs_ft (fine-tuned) for best quality (~9.2 dB SDR on MUSDB18-HQ).
"""
import logging
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torchaudio
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class VocalSeparator:
    """
    Separate vocals from mixed audio using a Demucs HTDemucs model.

    The separated vocals are significantly cleaner for ASR than the original mix,
    reducing transcription WER by ~3-5% (per arxiv:2506.15514).

    The Demucs backend is loaded lazily on the first call to ``separate``.
    The high-level ``demucs.api.Separator`` is preferred; when that module
    cannot be imported the class falls back to ``demucs.pretrained.get_model``
    combined with ``demucs.apply.apply_model``.
    """

    def __init__(
        self,
        model_name: str = "htdemucs_ft",
        device: Optional[str] = None,
        segment_seconds: float = 7.8,
        overlap: float = 0.25,
        shifts: int = 1,
    ) -> None:
        """
        Store the configuration; no model is loaded until first use.

        Args:
            model_name: Demucs model to use. Options:
                - "htdemucs_ft": Best quality, per-source fine-tuned (~9.2 dB SDR)
                - "htdemucs": Base model, slightly faster download (~8.7 dB SDR)
                - "htdemucs_6s": 6-stem (adds guitar, piano)
            device: "cuda", "cpu", or "mps". Auto-detected if None
                (CUDA first, then MPS, then CPU).
            segment_seconds: Processing chunk size. Lower = less VRAM.
                - 7.8: Default (matches training), ~4-6 GB VRAM
                - 4.0: For 8 GB GPUs
                - 2.0: For CPU processing
                Only forwarded to the ``Separator`` API; the low-level
                fallback relies on the model's own segment setting.
            overlap: Overlap ratio between chunks (0.25 = 25%, matches paper).
            shifts: Number of random time shifts whose predictions are
                averaged. 1 is the Demucs default, 0 skips the shift trick,
                5-10 gives better quality but is N× slower.
        """
        self.model_name = model_name
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.shifts = shifts

        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"

        # Exactly one of these is populated by ``_load_model``.
        self._model = None
        self._separator = None

    def _load_model(self):
        """Lazy-load the Demucs backend on first use; later calls are no-ops."""
        # Either backend counts as "loaded". Checking only ``_model`` would
        # rebuild the Separator (and reload its weights) on every call.
        if self._model is not None or self._separator is not None:
            return

        try:
            # Try high-level Separator API first (demucs >= 4.1)
            from demucs.api import Separator

            self._separator = Separator(
                model=self.model_name,
                device=self.device,
                shifts=self.shifts,
                segment=self.segment_seconds,
                overlap=self.overlap,
            )
            backend = "Separator API"
        except ImportError:
            # Fallback to low-level API
            from demucs.pretrained import get_model

            model = get_model(self.model_name)
            model.eval()
            model.to(self.device)
            self._model = model
            backend = "low-level API"

        logger.info(
            "Loaded Demucs via %s: %s on %s", backend, self.model_name, self.device
        )

    @property
    def sample_rate(self) -> int:
        """Demucs native sample rate (always 44100)."""
        return 44100

    @staticmethod
    def _to_stereo(wav: torch.Tensor) -> torch.Tensor:
        """Coerce a [channels, samples] waveform to exactly two channels."""
        channels = wav.shape[0]
        if channels == 1:
            # Duplicate the mono channel.
            return wav.repeat(2, 1)
        if channels > 2:
            # Take the first 2 channels.
            return wav[:2]
        return wav

    def _prepare_vocals(
        self, vocals: torch.Tensor, target_sr: int, mono: bool
    ) -> np.ndarray:
        """Downmix/resample a native-rate vocal stem and return float32 numpy."""
        if mono:
            vocals = vocals.mean(dim=0)  # [N]
        if self.sample_rate != target_sr:
            if vocals.dim() == 1:
                vocals = vocals.unsqueeze(0)
            vocals = torchaudio.functional.resample(
                vocals, self.sample_rate, target_sr
            )
            if mono:
                vocals = vocals.squeeze(0)
        # ``.numpy()`` only works on CPU tensors that do not require grad.
        return vocals.detach().cpu().numpy().astype(np.float32)

    def separate(self, audio_path: str) -> dict[str, torch.Tensor]:
        """
        Separate audio into stems.

        Args:
            audio_path: Path to audio file (any format supported by torchaudio)

        Returns:
            Dict mapping stem name → CPU tensor [channels, samples] at 44100 Hz.
            Keys for the 4-stem models: "drums", "bass", "other", "vocals".
        """
        self._load_model()

        wav, sr = torchaudio.load(audio_path)
        # Resample to the model's native 44100 Hz.
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        # Demucs expects 2-channel input.
        wav = self._to_stereo(wav)

        if self._separator is not None:
            # High-level API
            _, stems = self._separator.separate_tensor(wav.to(self.device))
        else:
            # Low-level API
            from demucs.apply import apply_model

            wav_batch = wav.unsqueeze(0).to(self.device)  # [1, 2, N]
            with torch.no_grad():
                sources = apply_model(
                    self._model,
                    wav_batch,
                    device=self.device,
                    shifts=self.shifts,
                    split=True,
                    overlap=self.overlap,
                    progress=False,
                )
            # sources: [1, num_sources, 2, N]
            stems = dict(zip(self._model.sources, sources[0]))

        # Always hand back CPU tensors so both backends behave the same and
        # callers can convert to numpy regardless of the compute device.
        return {name: stem.detach().cpu() for name, stem in stems.items()}

    def extract_vocals(
        self,
        audio_path: str,
        target_sr: int = 16000,
        mono: bool = True,
    ) -> tuple[np.ndarray, int]:
        """
        Extract vocals and prepare for ASR.

        Args:
            audio_path: Path to audio file
            target_sr: Target sample rate for ASR (16000 for Whisper)
            mono: Convert to mono (required by most ASR models)

        Returns:
            (vocals_array, sample_rate) — numpy float32 array ready for ASR.
            The array is [N] when ``mono`` is True, otherwise [2, N].
        """
        stems = self.separate(audio_path)
        vocals = stems["vocals"]  # [2, N] at 44100 Hz
        return self._prepare_vocals(vocals, target_sr, mono), target_sr

    def extract_vocals_full_rate(self, audio_path: str) -> tuple[np.ndarray, int]:
        """
        Extract vocals at full 44100 Hz for onset/offset analysis.

        Returns:
            (vocals_mono_array, 44100) — numpy float32 at native rate
        """
        # Same pipeline as ``extract_vocals`` with resampling disabled.
        return self.extract_vocals(audio_path, target_sr=self.sample_rate, mono=True)
|