lyric-sync / lyric_sync /separate.py
rikhoffbauer2's picture
Upload lyric_sync/separate.py
4b10521 verified
"""
Vocal stem separation using Demucs (HTDemucs).
Extracts clean vocals from a mixed audio track for downstream transcription.
Uses htdemucs_ft (fine-tuned) for best quality (~9.2 dB SDR on MUSDB18-HQ).
"""
import logging
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torchaudio
logger = logging.getLogger(__name__)


class VocalSeparator:
    """
    Separate vocals from mixed audio using a Demucs HTDemucs model.

    The separated vocals are significantly cleaner for ASR than the original mix,
    reducing transcription WER by ~3-5% (per arxiv:2506.15514).

    The model is loaded lazily on first use and cached on the instance, so one
    ``VocalSeparator`` can be reused for many files without reloading weights.
    """

    def __init__(
        self,
        model_name: str = "htdemucs_ft",
        device: Optional[str] = None,
        segment_seconds: float = 7.8,
        overlap: float = 0.25,
        shifts: int = 1,
    ):
        """
        Args:
            model_name: Demucs model to use. Options:
                - "htdemucs_ft": Best quality, per-source fine-tuned (~9.2 dB SDR)
                - "htdemucs": Base model, slightly faster download (~8.7 dB SDR)
                - "htdemucs_6s": 6-stem (adds guitar, piano)
            device: "cuda", "cpu", or "mps". Auto-detected if None
                (CUDA, then MPS, then CPU).
            segment_seconds: Processing chunk size. Lower = less VRAM.
                - 7.8: Default (matches training), ~4-6 GB VRAM
                - 4.0: For 8 GB GPUs
                - 2.0: For CPU processing
            overlap: Overlap ratio between chunks (0.25 = 25%, matches paper).
            shifts: Test-time shift augmentation. 1=disabled, 5-10=better quality
                but N× slower.
        """
        self.model_name = model_name
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.shifts = shifts
        self.device = device if device is not None else self._default_device()
        # Exactly one of these is populated by _load_model(), depending on
        # which Demucs API is importable.
        self._model = None
        self._separator = None

    @staticmethod
    def _default_device() -> str:
        """Return the best available torch device name: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _load_model(self) -> None:
        """
        Lazy-load the Demucs model on first use; later calls are no-ops.

        Prefers the high-level ``demucs.api.Separator`` (demucs >= 4.1) and
        falls back to ``demucs.pretrained.get_model`` when ``demucs.api`` cannot
        be imported.

        Raises:
            ImportError: if Demucs is not installed.
        """
        # Either attribute being set means a model is already cached. Checking
        # only ``_model`` would rebuild the Separator (and reload the weights)
        # on every call, because the Separator path never sets ``_model``.
        if self._separator is not None or self._model is not None:
            return
        try:
            from demucs.api import Separator
        except ImportError:
            self._load_low_level_model()
            return
        self._separator = Separator(
            model=self.model_name,
            device=self.device,
            shifts=self.shifts,
            overlap=self.overlap,
            segment=self.segment_seconds,
        )
        logger.info(
            "Loaded Demucs via Separator API: %s on %s", self.model_name, self.device
        )

    def _load_low_level_model(self) -> None:
        """Load the model through ``demucs.pretrained.get_model`` (demucs < 4.1)."""
        try:
            from demucs.pretrained import get_model
        except ImportError as err:
            raise ImportError(
                "Demucs is required for vocal separation (pip install demucs)"
            ) from err
        model = get_model(self.model_name)
        model.eval()
        model.to(self.device)
        # Only cache once fully initialized, so a failure above can be retried.
        self._model = model
        logger.info(
            "Loaded Demucs via low-level API: %s on %s", self.model_name, self.device
        )

    @property
    def sample_rate(self) -> int:
        """Demucs native sample rate (always 44100)."""
        return 44100

    def separate(self, audio_path: str) -> dict[str, torch.Tensor]:
        """
        Separate audio into stems.

        Args:
            audio_path: Path to audio file (any format supported by torchaudio).

        Returns:
            Dict mapping stem name -> CPU tensor [2, samples] at 44100 Hz.
            Keys are the model's sources, e.g. "drums", "bass", "other",
            "vocals" ("htdemucs_6s" also yields "guitar" and "piano").
        """
        self._load_model()
        wav = self._load_audio(audio_path)  # [2, N] on CPU
        # The whole track stays on the CPU: apply_model moves only the chunk it
        # is currently processing to ``self.device``, which bounds GPU memory
        # for long tracks.
        with torch.no_grad():
            if self._separator is not None:
                _, stems = self._separator.separate_tensor(wav)
            else:
                stems = self._separate_low_level(wav)
        # Always return CPU tensors so callers can safely call ``.numpy()``.
        return {name: stem.cpu() for name, stem in stems.items()}

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load ``audio_path`` as a [2, N] tensor resampled to ``self.sample_rate``."""
        wav, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        return self._to_stereo(wav)

    @staticmethod
    def _to_stereo(wav: torch.Tensor) -> torch.Tensor:
        """
        Coerce a [channels, N] waveform to the 2 channels Demucs expects.

        Mono input is duplicated; channels beyond the first two are dropped.
        """
        channels = wav.shape[0]
        if channels == 1:
            return wav.repeat(2, 1)
        if channels > 2:
            return wav[:2]
        return wav

    def _separate_low_level(self, wav: torch.Tensor) -> dict[str, torch.Tensor]:
        """Separate a [2, N] waveform with ``demucs.apply.apply_model``."""
        from demucs.apply import apply_model

        # Normalize the mix the same way the Demucs CLI and Separator API do,
        # then undo it on the outputs. The epsilon guards against silent input.
        ref = wav.mean(dim=0)
        offset = ref.mean()
        scale = ref.std() + 1e-8
        sources = apply_model(
            self._model,
            ((wav - offset) / scale).unsqueeze(0),  # [1, 2, N]
            device=self.device,
            shifts=self.shifts,
            split=True,
            overlap=self.overlap,
            segment=self.segment_seconds,
            progress=False,
        )
        sources = sources[0] * scale + offset  # [num_sources, 2, N]
        return dict(zip(self._model.sources, sources))

    def extract_vocals(
        self,
        audio_path: str,
        target_sr: int = 16000,
        mono: bool = True,
    ) -> tuple[np.ndarray, int]:
        """
        Extract vocals and prepare them for ASR.

        Args:
            audio_path: Path to audio file.
            target_sr: Target sample rate for ASR (16000 for Whisper).
            mono: Convert to mono (required by most ASR models).

        Returns:
            (vocals_array, sample_rate): a numpy float32 array shaped [N] when
            ``mono`` else [2, N], and ``target_sr``.

        Raises:
            ValueError: if ``target_sr`` is not positive.
        """
        # Validate before the (expensive) separation runs.
        if target_sr <= 0:
            raise ValueError(f"target_sr must be positive, got {target_sr}")
        vocals = self.separate(audio_path)["vocals"]  # [2, N] at 44100 Hz
        return self._prepare_vocals(vocals, target_sr, mono), target_sr

    def _prepare_vocals(
        self, vocals: torch.Tensor, target_sr: int, mono: bool
    ) -> np.ndarray:
        """Optionally downmix, then resample a [2, N] stem to a float32 array."""
        if mono:
            vocals = vocals.mean(dim=0)  # [N]
        if target_sr != self.sample_rate:
            # resample operates on the last (time) dim, so [N] and [2, N] both work.
            vocals = torchaudio.functional.resample(vocals, self.sample_rate, target_sr)
        return vocals.cpu().numpy().astype(np.float32, copy=False)

    def extract_vocals_full_rate(self, audio_path: str) -> tuple[np.ndarray, int]:
        """
        Extract mono vocals at the native 44100 Hz for onset/offset analysis.

        Returns:
            (vocals_mono_array, 44100): numpy float32 array shaped [N].
        """
        return self.extract_vocals(audio_path, target_sr=self.sample_rate, mono=True)