Richard Zhu committed on
Commit
94ce22b
·
1 Parent(s): b4f2105

Add LarsNet drum separator

Browse files
__pycache__/larsnet.cpython-311.pyc ADDED
Binary file (8.38 kB). View file
 
__pycache__/unet.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchaudio as ta
10
+ import soundfile as sf
11
+ import gradio as gr
12
+ from tqdm import tqdm
13
+ from typing import Union, Tuple, Optional
14
+ from torch import Tensor
15
+
16
+ from pyharp import build_endpoint, ModelCard
17
+
18
+
19
+ # ─────────────────────────────────────────────
20
+ # UNet Utilities
21
+ # ─────────────────────────────────────────────
22
+
23
class UNetUtils:
    """STFT helpers plus fold/unfold utilities for fixed-size U-Net inputs.

    ``F``/``T`` are the frequency/time sizes the U-Net expects; the STFT
    settings default to a 4096-point FFT with a periodic Hann window and a
    75%-overlap hop (``win_length // 4``).
    """

    def __init__(self, F=None, T=None, n_fft=4096, win_length=None,
                 hop_length=None, center=True, device='cpu'):
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 4
        self.hann_window = torch.hann_window(self.win_length, periodic=True).to(device)
        self.center = center
        self.device = device
        self.F = F
        self.T = T

    def fold_unet_inputs(self, x):
        """Zero-pad the time axis to a multiple of ``T`` and stack the
        resulting ``T``-frame chunks along the batch axis."""
        n_frames = x.size(-1)
        target = math.ceil(n_frames / self.T) * self.T
        x = F.pad(x, (0, target - n_frames))
        if n_frames < self.T:
            return x
        chunks = torch.split(x, self.T, dim=-1)
        return torch.cat(chunks, dim=0)

    def unfold_unet_outputs(self, x, input_size):
        """Invert :meth:`fold_unet_inputs`, trimming back to the original length."""
        batch, n_frames = input_size[0], input_size[-1]
        if x.size(0) != batch:
            # Re-assemble the T-frame chunks along the time axis.
            x = torch.cat(torch.split(x, batch, dim=0), dim=-1)
        return x[..., :n_frames]

    def trim_freq_dim(self, x):
        """Keep only the first ``F`` frequency bins."""
        return x[..., :self.F, :]

    def pad_freq_dim(self, x):
        """Zero-pad the frequency axis back to the full ``n_fft // 2 + 1`` bins."""
        missing = (self.n_fft // 2 + 1) - x.size(-2)
        return F.pad(x, (0, 0, 0, missing))

    def pad_stft_input(self, x):
        """Zero-pad the waveform so the STFT frame count comes out whole."""
        extra = (-(x.size(-1) - self.win_length) % self.hop_length) % self.win_length
        return F.pad(x, (0, extra))

    def _stft(self, x):
        # Thin wrapper fixing this instance's STFT parameters.
        return torch.stft(input=x, n_fft=self.n_fft, window=self.hann_window,
                          win_length=self.win_length, hop_length=self.hop_length,
                          center=self.center, return_complex=True)

    def _istft(self, x, trim_length=None):
        # Inverse of _stft; optionally trims/pads the result to trim_length samples.
        return torch.istft(input=x, n_fft=self.n_fft, window=self.hann_window,
                           win_length=self.win_length, hop_length=self.hop_length,
                           center=self.center, length=trim_length)

    def batch_stft(self, x, pad=True, return_complex=False):
        """STFT over arbitrarily-batched waveforms.

        Returns the complex spectrogram when ``return_complex`` is set,
        otherwise a ``(magnitude, phase)`` pair with the batch dims preserved.
        """
        shape = x.size()
        flat = x.reshape(-1, shape[-1])
        if pad:
            flat = self.pad_stft_input(flat)
        spec = self._stft(flat)
        spec = spec.reshape(shape[:-1] + spec.shape[-2:])
        if return_complex:
            return spec
        return spec.abs(), spec.angle()

    def batch_istft(self, magnitude, phase, trim_length=None):
        """Inverse of :meth:`batch_stft` from a ``(magnitude, phase)`` pair."""
        spec = torch.polar(magnitude, phase)
        shape = spec.size()
        wav = self._istft(spec.reshape(-1, shape[-2], shape[-1]), trim_length)
        return wav.reshape(shape[:-2] + wav.shape[-1:])
88
+
89
+
90
+ # ─────────────────────────────────────────────
91
+ # UNet Blocks
92
+ # ─────────────────────────────────────────────
93
+
94
class UNetEncoderBlock(nn.Module):
    """Encoder stage: strided Conv2d -> BatchNorm -> LeakyReLU.

    ``forward`` returns both the activated output (input to the next stage)
    and the raw convolution output (kept as the skip connection).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activ = nn.LeakyReLU(relu_slope)
        # He-style init matched to the leaky-ReLU slope; bias starts at zero.
        nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='leaky_relu', a=relu_slope)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        pre_activation = self.conv(x)
        activated = self.activ(self.bn(pre_activation))
        return activated, pre_activation
108
+
109
+
110
class UNetDecoderBlock(nn.Module):
    """Decoder stage: ConvTranspose2d -> ReLU -> BatchNorm -> Dropout.

    Note: unlike the encoder, the activation here is applied *before* the
    batch norm — this ordering is preserved exactly.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), output_padding=(1, 1), dropout=0.0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, output_padding=output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.activ = nn.ReLU()

    def forward(self, x):
        h = self.conv_trans(x)
        h = self.activ(h)
        h = self.bn(h)
        return self.dropout(h)
122
+
123
+
124
+ # ─────────────────────────────────────────────
125
+ # UNet Models
126
+ # ─────────────────────────────────────────────
127
+
128
class UNet(nn.Module):
    """Magnitude-spectrogram masking U-Net.

    Six strided-conv encoder stages (audio_channels -> 16 -> ... -> 512)
    mirrored by six transposed-conv decoder stages with skip connections,
    followed by a dilated-conv + sigmoid mask head.  ``forward`` consumes a
    magnitude spectrogram and returns the masked spectrogram and the mask.
    """

    def __init__(self, input_size: Tuple[int, ...] = (2, 2048, 512),
                 power: float = 1.0, device: Optional[str] = None):
        super().__init__()
        self.input_size = input_size
        audio_channels, f_size, t_size = input_size
        # NOTE: ``power`` is accepted for interface compatibility but unused here.
        self.utils = UNetUtils(F=f_size, T=t_size, device=device)
        # Normalizes over the frequency axis (input is transposed before use).
        self.input_norm = nn.BatchNorm2d(f_size)
        # Encoder channel widths; attribute names enc1..enc6 are preserved so
        # state_dict keys stay compatible with the pretrained checkpoints.
        widths = (audio_channels, 16, 32, 64, 128, 256, 512)
        for i in range(6):
            setattr(self, f"enc{i + 1}", UNetEncoderBlock(widths[i], widths[i + 1]))
        # Decoder stages; dec2..dec6 take doubled inputs (skip concatenation).
        self.dec1 = UNetDecoderBlock(512, 256, dropout=0.5)
        self.dec2 = UNetDecoderBlock(512, 128, dropout=0.5)
        self.dec3 = UNetDecoderBlock(256, 64, dropout=0.5)
        self.dec4 = UNetDecoderBlock(128, 32)
        self.dec5 = UNetDecoderBlock(64, 16)
        self.dec6 = UNetDecoderBlock(32, audio_channels)
        self.mask_layer = nn.Sequential(
            nn.Conv2d(audio_channels, audio_channels, kernel_size=(4, 4),
                      dilation=(2, 2), padding=3),
            nn.Sigmoid(),
        )
        nn.init.kaiming_uniform_(self.mask_layer[0].weight)
        nn.init.zeros_(self.mask_layer[0].bias)
        if device is not None:
            self.to(device)

    def produce_mask(self, x: Tensor) -> Tensor:
        """Run the encoder/decoder stack and return a sigmoid mask for ``x``."""
        x = self.input_norm(x.transpose(1, 2)).transpose(1, 2)
        skips = []
        h = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6):
            h, pre = stage(h)
            skips.append(pre)
        # Bottleneck uses the raw (pre-activation) output of enc6.
        u = self.dec1(skips[5])
        decoders = (self.dec2, self.dec3, self.dec4, self.dec5, self.dec6)
        for stage, skip in zip(decoders, skips[4::-1]):
            u = stage(torch.cat([skip, u], dim=1))
        return self.mask_layer(u)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Mask a magnitude spectrogram; returns ``(masked_spec, mask)``."""
        original_size = x.size()
        x = self.utils.fold_unet_inputs(x)
        trimmed = self.utils.trim_freq_dim(x)
        mask = self.utils.pad_freq_dim(self.produce_mask(trimmed))
        return (self.utils.unfold_unet_outputs(x * mask, original_size),
                self.utils.unfold_unet_outputs(mask, original_size))
181
+
182
+
183
class UNetWaveform(UNet):
    """UNet variant that operates on raw waveforms.

    Computes the STFT, masks the magnitude with :class:`UNet`, and
    resynthesizes a waveform using the mixture phase.

    Accepts a 1-D mono waveform ``(time,)``, a 2-D ``(channels, time)``
    tensor, or a batched 3-D tensor.  Mono input is duplicated to two
    channels because the network's first conv expects two input channels.
    """

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        if x.dim() == 1:
            # (time,) -> (1, time); duplicated to two channels below.
            x = x.unsqueeze(0)
        if x.dim() == 2:
            if x.size(0) == 1:
                # Fix: a mono (1, time) tensor previously reached the
                # two-channel conv unchanged and crashed; duplicate it.
                x = x.repeat(2, 1)
            x = x.unsqueeze(0)  # add batch dim -> (1, channels, time)
        mag, phase = self.utils.batch_stft(x)
        mag_hat, mask = super().forward(mag)
        waveform = self.utils.batch_istft(mag_hat, phase, trim_length=x.size(-1))
        return waveform, mask
192
+
193
+
194
+ # ─────────────────────────────────────────────
195
+ # LarsNet
196
+ # ─────────────────────────────────────────────
197
+
198
class LarsNet(nn.Module):
    """Five-stem drum separator wrapping one pretrained UNet per stem.

    Loads a UNet (or UNetWaveform) checkpoint for each stem listed under
    ``inference_models`` in the YAML config and dispatches separation to one
    of three strategies: plain per-stem masking, alpha-Wiener filtering, or
    raw masked STFT output.

    Args:
        wiener_filter: if True, combine per-stem magnitude estimates with a
            Wiener-style soft mask (see :meth:`separate_wiener`).
        wiener_exponent: exponent applied to each magnitude estimate before
            normalization in the Wiener filter.
        config: path to the YAML configuration file.
        return_stft: if True, return complex STFTs instead of waveforms.
        device: torch device string models and inputs are moved to.
    """

    def __init__(self, wiener_filter=False, wiener_exponent=1.0,
                 config: Union[str, Path] = "config.yaml",
                 return_stft=False, device='cpu', **kwargs):
        super().__init__(**kwargs)
        # ``config`` is rebound from path to parsed dict from here on.
        with open(config, "r") as f:
            config = yaml.safe_load(f)

        self.device = device
        self.wiener_filter = wiener_filter
        self.wiener_exponent = wiener_exponent
        self.return_stft = return_stft
        # Dict view over stem names; iteration order matches self.models,
        # which separate_wiener relies on when zipping stems and predictions.
        self.stems = config['inference_models'].keys()
        self.utils = UNetUtils(device=self.device)
        self.sr = config['global']['sr']
        # Plain dict (not nn.ModuleDict): sub-models are hidden from
        # nn.Module bookkeeping (state_dict, .to()) — presumably intentional
        # since each model is created directly on the target device; verify.
        self.models = {}

        print('Loading UNet models...')
        for stem in tqdm(self.stems):
            checkpoint_path = Path(config['inference_models'][stem])
            F = config[stem]['F']
            T = config[stem]['T']
            # Spectrogram-in models when a spectrogram-domain result is needed
            # (Wiener filtering or raw STFT); waveform-in/out otherwise.
            model = (UNet if (wiener_filter or return_stft) else UNetWaveform)(
                input_size=(2, F, T), device=self.device
            )
            checkpoint = torch.load(str(checkpoint_path), map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.models[stem] = model

    @staticmethod
    def _fix_dim(x):
        # Normalize input to (batch, channels, time).
        # NOTE(review): a 2-D mono tensor (1, time) is NOT duplicated to two
        # channels here, so it would reach the stereo UNet with a single
        # channel — confirm mono handling upstream.
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        return x

    def separate(self, x):
        """Waveform-domain separation; returns ``{stem: waveform tensor}``."""
        out = {}
        x = x.to(self.device)
        for stem, model in tqdm(self.models.items()):
            # Models here are UNetWaveform: waveform in, waveform out.
            y, _ = model(x)
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_wiener(self, x):
        """Wiener-filtered separation; returns ``{stem: waveform tensor}``."""
        out = {}
        mag_pred = []
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        for stem, model in tqdm(self.models.items()):
            _, mask = model(mag)
            mag_pred.append((mask * mag) ** self.wiener_exponent)
        pred_sum = sum(mag_pred)
        for stem, pred in zip(self.stems, mag_pred):
            # Soft mask: each stem's share of the summed magnitude estimates;
            # epsilon guards against division by zero in silent bins.
            wiener_mask = pred / (pred_sum + 1e-7)
            y = self.utils.batch_istft(mag * wiener_mask, phase, trim_length=x.size(-1))
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_stft(self, x):
        """Spectrogram-domain separation; returns ``{stem: complex STFT}``
        built from the predicted magnitudes and the mixture phase."""
        out = {}
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        for stem, model in tqdm(self.models.items()):
            mag_pred, _ = model(mag)
            out[stem] = torch.polar(mag_pred, phase).squeeze(0).detach()
        return out

    def forward(self, x):
        """Separate ``x`` (tensor or audio file path) into drum stems.

        File inputs are loaded with torchaudio and resampled to ``self.sr``
        if needed; the return type depends on the constructor flags.
        """
        if isinstance(x, (str, Path)):
            x, sr_ = ta.load(str(x))
            if sr_ != self.sr:
                x = ta.functional.resample(x, sr_, self.sr)
        if self.return_stft:
            return self.separate_stft(x)
        elif self.wiener_filter:
            return self.separate_wiener(x)
        else:
            return self.separate(x)
279
+
280
+
281
+ # ─────────────────────────────────────────────
282
+ # App
283
+ # ─────────────────────────────────────────────
284
+
285
# HARP model card displayed by the plugin UI.
model_card = ModelCard(
    name="LarsNet Drum Stem Separator",
    description="Separates a drum mix into individual drum stems: Kick, Snare, Toms, Hi-Hat, and Cymbals.",
    author="A. I. Mezza, et al.",
    tags=["drums", "demucs", "source-separation", "pyharp", "stems", "multi-output"],
)

# Global model instance, loaded once at startup: CPU inference with plain
# per-stem waveform masking (no Wiener filtering).
MODEL = LarsNet(wiener_filter=False, device="cpu", config="config.yaml")
293
+
294
+
295
@torch.inference_mode()
def process_fn(audio_path: str):
    """Separate ``audio_path`` into the five drum stems for the HARP endpoint.

    Returns a tuple of five wav-file paths in the fixed order expected by the
    output components: kick, snare, toms, hihat, cymbals.
    """
    stems = MODEL(audio_path)
    # Fix: write into a fresh per-call temporary directory instead of a shared
    # "outputs/" folder, so concurrent requests cannot overwrite each other's
    # files (tempfile was already imported but unused).
    output_dir = Path(tempfile.mkdtemp(prefix="larsnet_"))
    output_paths = []
    for stem_name in ["kick", "snare", "toms", "hihat", "cymbals"]:
        out_path = output_dir / f"{stem_name}.wav"
        # soundfile expects (frames, channels), hence the transpose.
        sf.write(out_path, stems[stem_name].cpu().numpy().T, MODEL.sr)
        output_paths.append(str(out_path))
    return tuple(output_paths)
306
+
307
+
308
# Build the Gradio/HARP UI: one required audio input (the drum mix) and one
# output component per separated stem, wired to process_fn via pyharp.
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Drum Mix (Input)").harp_required(True)
    output_kick = gr.Audio(type="filepath", label="Kick")
    output_snare = gr.Audio(type="filepath", label="Snare")
    output_toms = gr.Audio(type="filepath", label="Toms")
    output_hihat = gr.Audio(type="filepath", label="Hi-Hat")
    output_cymbals = gr.Audio(type="filepath", label="Cymbals")

    # Output component order must match the stem order returned by process_fn.
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_kick, output_snare, output_toms, output_hihat, output_cymbals],
        process_fn=process_fn,
    )

demo.queue().launch(show_error=True, share=True)
config.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ global:
2
+ sr: 44100 # Hz
3
+ segment: 11.85 # seconds
4
+ shift: 2 # seconds
5
+ sample_rate: 44100 # Hz
6
+ n_workers: 16
7
+ prefetch_factor: 6
8
+
9
+
10
+ inference_models:
11
+ kick: 'pretrained_larsnet_models/kick/pretrained_kick_unet.pth'
12
+ snare: 'pretrained_larsnet_models/snare/pretrained_snare_unet.pth'
13
+ toms: 'pretrained_larsnet_models/toms/pretrained_toms_unet.pth'
14
+ hihat: 'pretrained_larsnet_models/hihat/pretrained_hihat_unet.pth'
15
+ cymbals: 'pretrained_larsnet_models/cymbals/pretrained_cymbals_unet.pth'
16
+
17
+
18
+ data_augmentation:
19
+ augmentation_prob: 0.5
20
+ kit_swap_augment_prob: 0.5
21
+ doubling_augment_prob: 0.3
22
+ pitch_shift_augment_prob: 0.3
23
+ saturation_augment_prob: 0.3
24
+ channel_swap_augment_prob: 0.5
25
+ remix_augment_prob: 0.3
26
+
27
+
28
+ kick:
29
+ F: 2048
30
+ T: 512
31
+ batch_size: 24
32
+ learning_rate: 1e-4
33
+ epochs: 22
34
+ training_mode: 'stft'
35
+ model_id: 'default_kick_unet'
36
+
37
+
38
+ snare:
39
+ F: 2048
40
+ T: 512
41
+ batch_size: 24
42
+ learning_rate: 1e-4
43
+ epochs: 22
44
+ training_mode: 'stft'
45
+ model_id: 'default_snare_unet'
46
+
47
+
48
+ toms:
49
+ F: 2048
50
+ T: 512
51
+ batch_size: 24
52
+ learning_rate: 1e-4
53
+ epochs: 22
54
+ training_mode: 'stft'
55
+ model_id: 'default_toms_unet'
56
+
57
+
58
+ hihat:
59
+ F: 2048
60
+ T: 512
61
+ batch_size: 24
62
+ learning_rate: 1e-4
63
+ epochs: 22
64
+ training_mode: 'stft'
65
+ model_id: 'default_hihat_unet'
66
+
67
+
68
+ cymbals:
69
+ F: 2048
70
+ T: 512
71
+ batch_size: 24
72
+ learning_rate: 1e-4
73
+ epochs: 22
74
+ training_mode: 'stft'
75
+ model_id: 'default_cymbals_unet'
pretrained_larsnet_models/.DS_Store ADDED
Binary file (8.2 kB). View file
 
pretrained_larsnet_models/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+ Direct download link: https://drive.google.com/uc?id=1U8-5924B1ii1cjv9p0MTPzayb00P4qoL&export=download
pretrained_larsnet_models/LICENSE.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The present LarsNet model checkpoints © 2023 by Alessandro Ilic Mezza (Image and Sound Processing Lab, Dipartimento di Elettronica, Informazione e Bioingegneria, Politecnico di Milano, Milan, Italy) are licensed under Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
pretrained_larsnet_models/cymbals/pretrained_cymbals_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:889804c465e2c6fdbbe45febbd4daafef01d8d22f9f34fb968695ab9793858d0
3
+ size 118037828
pretrained_larsnet_models/hihat/pretrained_hihat_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac004ac24a26e77f6d39671ed8ba45c3f476e956900d469ebb402386dad11dd7
3
+ size 118037828
pretrained_larsnet_models/kick/pretrained_kick_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed821b6a69b1ef0413ac9ef7958ac9a37e5f4f056e66f0191496bf903ac2628d
3
+ size 118037828
pretrained_larsnet_models/snare/pretrained_snare_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78bd75001ff6c52b23de6245cecd1606adbd45348e0216f6d8c116553013e4fd
3
+ size 118037828
pretrained_larsnet_models/toms/pretrained_toms_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:699112983948e14d805890a0723e5b402203a3a1db1cd0ccdd8a80058062ef77
3
+ size 118037828
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # pip install -r requirements.txt
2
+ -e git+https://github.com/TEAMuP-dev/pyharp.git#egg=pyharp
3
+ torch==2.7
4
+ torchaudio==2.7
5
+ pyyaml
6
+ tqdm
7
+ soundfile
8
+ torchcodec==0.7