rikhoffbauer2 commited on
Commit
4b10521
·
verified ·
1 Parent(s): 8041e59

Upload lyric_sync/separate.py

Browse files
Files changed (1) hide show
  1. lyric_sync/separate.py +189 -0
lyric_sync/separate.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vocal stem separation using Demucs (HTDemucs).
3
+
4
+ Extracts clean vocals from a mixed audio track for downstream transcription.
5
+ Uses htdemucs_ft (fine-tuned) for best quality (~9.2 dB SDR on MUSDB18-HQ).
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class VocalSeparator:
20
+ """
21
+ Separate vocals from mixed audio using Demucs HTDemucs model.
22
+
23
+ The separated vocals are significantly cleaner for ASR than the original mix,
24
+ reducing transcription WER by ~3-5% (per arxiv:2506.15514).
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ model_name: str = "htdemucs_ft",
30
+ device: Optional[str] = None,
31
+ segment_seconds: float = 7.8,
32
+ overlap: float = 0.25,
33
+ shifts: int = 1,
34
+ ):
35
+ """
36
+ Args:
37
+ model_name: Demucs model to use. Options:
38
+ - "htdemucs_ft": Best quality, per-source fine-tuned (~9.2 dB SDR)
39
+ - "htdemucs": Base model, slightly faster download (~8.7 dB SDR)
40
+ - "htdemucs_6s": 6-stem (adds guitar, piano)
41
+ device: "cuda", "cpu", or "mps". Auto-detected if None.
42
+ segment_seconds: Processing chunk size. Lower = less VRAM.
43
+ - 7.8: Default (matches training), ~4-6 GB VRAM
44
+ - 4.0: For 8 GB GPUs
45
+ - 2.0: For CPU processing
46
+ overlap: Overlap ratio between chunks (0.25 = 25%, matches paper).
47
+ shifts: Test-time shift augmentation. 1=disabled, 5-10=better quality but N× slower.
48
+ """
49
+ self.model_name = model_name
50
+ self.segment_seconds = segment_seconds
51
+ self.overlap = overlap
52
+ self.shifts = shifts
53
+
54
+ if device is None:
55
+ if torch.cuda.is_available():
56
+ self.device = "cuda"
57
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
58
+ self.device = "mps"
59
+ else:
60
+ self.device = "cpu"
61
+ else:
62
+ self.device = device
63
+
64
+ self._model = None
65
+ self._separator = None
66
+
67
+ def _load_model(self):
68
+ """Lazy-load model on first use."""
69
+ if self._model is not None:
70
+ return
71
+
72
+ try:
73
+ # Try high-level Separator API first (demucs >= 4.1)
74
+ from demucs.api import Separator
75
+ self._separator = Separator(
76
+ model=self.model_name,
77
+ device=self.device,
78
+ segment=self.segment_seconds,
79
+ overlap=self.overlap,
80
+ )
81
+ logger.info(f"Loaded Demucs via Separator API: {self.model_name} on {self.device}")
82
+ except ImportError:
83
+ # Fallback to low-level API
84
+ from demucs.pretrained import get_model
85
+ self._model = get_model(self.model_name)
86
+ self._model.eval()
87
+ self._model.to(self.device)
88
+ logger.info(f"Loaded Demucs via low-level API: {self.model_name} on {self.device}")
89
+
90
+ @property
91
+ def sample_rate(self) -> int:
92
+ """Demucs native sample rate (always 44100)."""
93
+ return 44100
94
+
95
+ def separate(self, audio_path: str) -> dict[str, torch.Tensor]:
96
+ """
97
+ Separate audio into stems.
98
+
99
+ Args:
100
+ audio_path: Path to audio file (any format supported by torchaudio)
101
+
102
+ Returns:
103
+ Dict mapping stem name → tensor [channels, samples] at 44100 Hz.
104
+ Keys: "drums", "bass", "other", "vocals"
105
+ """
106
+ self._load_model()
107
+
108
+ # Load audio
109
+ wav, sr = torchaudio.load(audio_path)
110
+
111
+ # Resample to model's native 44100 Hz
112
+ if sr != self.sample_rate:
113
+ wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
114
+
115
+ # Ensure stereo (Demucs expects 2-channel)
116
+ if wav.shape[0] == 1:
117
+ wav = wav.repeat(2, 1)
118
+ elif wav.shape[0] > 2:
119
+ wav = wav[:2] # Take first 2 channels
120
+
121
+ if self._separator is not None:
122
+ # High-level API
123
+ _, stems = self._separator.separate_tensor(wav.to(self.device))
124
+ return stems
125
+ else:
126
+ # Low-level API
127
+ from demucs.apply import apply_model
128
+
129
+ wav_batch = wav.unsqueeze(0).to(self.device) # [1, 2, N]
130
+
131
+ with torch.no_grad():
132
+ sources = apply_model(
133
+ self._model,
134
+ wav_batch,
135
+ device=self.device,
136
+ shifts=self.shifts,
137
+ split=True,
138
+ overlap=self.overlap,
139
+ progress=False,
140
+ )
141
+ # sources: [1, num_sources, 2, N]
142
+ stems = {}
143
+ for idx, name in enumerate(self._model.sources):
144
+ stems[name] = sources[0, idx].cpu() # [2, N]
145
+ return stems
146
+
147
+ def extract_vocals(
148
+ self,
149
+ audio_path: str,
150
+ target_sr: int = 16000,
151
+ mono: bool = True,
152
+ ) -> tuple[np.ndarray, int]:
153
+ """
154
+ Extract vocals and prepare for ASR.
155
+
156
+ Args:
157
+ audio_path: Path to audio file
158
+ target_sr: Target sample rate for ASR (16000 for Whisper)
159
+ mono: Convert to mono (required by most ASR models)
160
+
161
+ Returns:
162
+ (vocals_array, sample_rate) — numpy float32 array ready for ASR
163
+ """
164
+ stems = self.separate(audio_path)
165
+ vocals = stems["vocals"] # [2, N] at 44100 Hz
166
+
167
+ if mono:
168
+ vocals = vocals.mean(dim=0) # [N]
169
+
170
+ # Resample to target SR
171
+ if self.sample_rate != target_sr:
172
+ if vocals.dim() == 1:
173
+ vocals = vocals.unsqueeze(0)
174
+ vocals = torchaudio.functional.resample(vocals, self.sample_rate, target_sr)
175
+ if mono:
176
+ vocals = vocals.squeeze(0)
177
+
178
+ return vocals.numpy().astype(np.float32), target_sr
179
+
180
+ def extract_vocals_full_rate(self, audio_path: str) -> tuple[np.ndarray, int]:
181
+ """
182
+ Extract vocals at full 44100 Hz for onset/offset analysis.
183
+
184
+ Returns:
185
+ (vocals_mono_array, 44100) — numpy float32 at native rate
186
+ """
187
+ stems = self.separate(audio_path)
188
+ vocals = stems["vocals"].mean(dim=0) # [N] mono at 44100
189
+ return vocals.numpy().astype(np.float32), self.sample_rate