Martin committed
Commit b3a1370 · 1 Parent(s): 0e88f8d

Initial commit

app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ import torch
+ import torchaudio
+ import yaml
+ from huggingface_hub import hf_hub_download
+
+ from models.mel_band_roformer import MelBandRoformer
+
+ # Load configuration
+ # config.yaml uses the `!!python/tuple` tag, which yaml.safe_load rejects,
+ # so an unsafe loader is required here
+ with open('config.yaml', 'r') as f:
+     config = yaml.unsafe_load(f)
+
+ # Define model loading function
+ def load_model():
+     # Repository and file path of the checkpoint on the Hugging Face Hub
+     repo_id = "GaboxR67/MelBandRoformers"
+     filename = "melbandroformers/instrumental/Inst_ExperimentalV1.ckpt"
+
+     # Download the checkpoint from Hugging Face
+     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     # Load the checkpoint
+     checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+     # Initialize the MelBandRoformer architecture from the config
+     model = MelBandRoformer(**config['model'])
+
+     # Some .ckpt files nest the weights under a 'state_dict' key
+     state_dict = checkpoint.get('state_dict', checkpoint)
+     model.load_state_dict(state_dict)
+     model.eval()
+
+     return model
+
+ # Initialize model
+ model = load_model()
+
+ def separate_audio(audio_file):
+     """
+     Process an audio file and separate the instrumental stem.
+     """
+     # Load audio
+     waveform, sample_rate = torchaudio.load(audio_file)
+
+     # Resample if necessary
+     if sample_rate != config['sample_rate']:
+         resampler = torchaudio.transforms.Resample(sample_rate, config['sample_rate'])
+         waveform = resampler(waveform)
+
+     # Run inference: the model expects a batch dimension,
+     # so add one and drop it from the output
+     with torch.no_grad():
+         instrumental = model(waveform.unsqueeze(0))[0]
+
+     # Save output
+     output_path = "output_instrumental.wav"
+     torchaudio.save(output_path, instrumental.cpu(), config['sample_rate'])
+
+     return output_path
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=separate_audio,
+     inputs=gr.Audio(type="filepath", label="Upload Audio"),
+     outputs=gr.Audio(label="Instrumental Output"),
+     title="MelBand Roformer Audio Separation",
+     description="Separate instrumental from vocals using the MelBand Roformer model"
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
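
For a quick check without launching the UI, the handler above can be called directly (illustrative only; assumes a local file named input.wav):

    # one-off separation without the Gradio interface
    result_path = separate_audio('input.wav')
    print(f'Instrumental written to {result_path}')
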
config.yaml ADDED
@@ -0,0 +1,50 @@
+ chunk_size: 485100
+ dim_f: 1024
+ dim_t: 1101
+ hop_length: 441
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+ model:
+   dim: 384
+   depth: 6
+   stereo: true
+   num_stems: 1
+   time_transformer_depth: 1
+   freq_transformer_depth: 1
+   num_bands: 60
+   dim_head: 64
+   heads: 8
+   attn_dropout: 0
+   ff_dropout: 0
+   flash_attn: True
+   dim_freqs_in: 1025
+   sample_rate: 44100
+   stft_n_fft: 2048
+   stft_hop_length: 441
+   stft_win_length: 2048
+   stft_normalized: False
+   mask_estimator_depth: 2
+   multi_stft_resolution_loss_weight: 1.0
+   multi_stft_resolutions_window_sizes: !!python/tuple
+     - 4096
+     - 2048
+     - 1024
+     - 512
+     - 256
+   multi_stft_hop_size: 147
+   multi_stft_normalized: False
+
+ training:
+   instruments:
+     - Instrumental
+     - Vocals
+   target_instrument: Instrumental
+   use_amp: True
+
+ inference:
+   batch_size: 1
+   dim_t: 1101
+   num_overlap: 2
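
Note: the `!!python/tuple` tag above means this file cannot be parsed with `yaml.safe_load`, and `utils.demix_track` below reads settings with attribute access (`config.inference.chunk_size`) even though `chunk_size` sits at the top level here. A minimal loading sketch under those assumptions (the `to_namespace` helper is illustrative, not part of this commit):

    import yaml
    from types import SimpleNamespace

    def to_namespace(obj):
        # recursively convert nested dicts into attribute-style namespaces
        if isinstance(obj, dict):
            return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
        return obj

    with open('config.yaml') as f:
        raw = yaml.unsafe_load(f)  # safe_load rejects the !!python/tuple tag

    config = to_namespace(raw)
    # demix_track expects config.inference.chunk_size; mirror the top-level value
    config.inference.chunk_size = raw['chunk_size']

get_model_from_config additionally calls dict(config.model), so a mapping-aware wrapper such as ml_collections.ConfigDict may be a better fit there.
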
models/mel_band_roformer/__init__.py ADDED
@@ -0,0 +1 @@
+ from models.mel_band_roformer.mel_band_roformer import MelBandRoformer
models/mel_band_roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,555 @@
+ from functools import partial
+
+ import torch
+ from torch import nn, einsum, Tensor
+ from torch.nn import Module, ModuleList
+ import torch.nn.functional as F
+
+ from models.mel_band_roformer.attend import Attend
+
+ from beartype.typing import Tuple, Optional, List, Callable
+ from beartype import beartype
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from einops import rearrange, pack, unpack, reduce, repeat
+
+ from librosa import filters
+
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+
+ def pad_at_dim(t, pad, dim=-1, value=0.):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value=value)
+
+
+ # norm
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+ # attention
+
+ class FeedForward(Module):
+     def __init__(
+             self,
+             dim,
+             mult=4,
+             dropout=0.
+     ):
+         super().__init__()
+         dim_inner = int(dim * mult)
+         self.net = nn.Sequential(
+             RMSNorm(dim),
+             nn.Linear(dim, dim_inner),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_inner, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(Module):
+     def __init__(
+             self,
+             dim,
+             heads=8,
+             dim_head=64,
+             dropout=0.,
+             rotary_embed=None,
+             flash=True
+     ):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         dim_inner = heads * dim_head
+
+         self.rotary_embed = rotary_embed
+
+         self.attend = Attend(flash=flash, dropout=dropout)
+
+         self.norm = RMSNorm(dim)
+         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+         self.to_gates = nn.Linear(dim, heads)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(dim_inner, dim, bias=False),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+         if exists(self.rotary_embed):
+             q = self.rotary_embed.rotate_queries_or_keys(q)
+             k = self.rotary_embed.rotate_queries_or_keys(k)
+
+         out = self.attend(q, k, v)
+
+         gates = self.to_gates(x)
+         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class Transformer(Module):
+     def __init__(
+             self,
+             *,
+             dim,
+             depth,
+             dim_head=64,
+             heads=8,
+             attn_dropout=0.,
+             ff_dropout=0.,
+             ff_mult=4,
+             norm_output=True,
+             rotary_embed=None,
+             flash_attn=True
+     ):
+         super().__init__()
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             self.layers.append(ModuleList([
+                 Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed,
+                           flash=flash_attn),
+                 FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+             ]))
+
+         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+     def forward(self, x):
+
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return self.norm(x)
+
+
+ # bandsplit module
+
+ class BandSplit(Module):
+     @beartype
+     def __init__(
+             self,
+             dim,
+             dim_inputs: Tuple[int, ...]
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_features = ModuleList([])
+
+         for dim_in in dim_inputs:
+             net = nn.Sequential(
+                 RMSNorm(dim_in),
+                 nn.Linear(dim_in, dim)
+             )
+
+             self.to_features.append(net)
+
+     def forward(self, x):
+         x = x.split(self.dim_inputs, dim=-1)
+
+         outs = []
+         for split_input, to_feature in zip(x, self.to_features):
+             split_output = to_feature(split_input)
+             outs.append(split_output)
+
+         return torch.stack(outs, dim=-2)
+
+
+ def MLP(
+         dim_in,
+         dim_out,
+         dim_hidden=None,
+         depth=1,
+         activation=nn.Tanh
+ ):
+     dim_hidden = default(dim_hidden, dim_in)
+
+     net = []
+     dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+     for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+         is_last = ind == (len(dims) - 2)
+
+         net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+         if is_last:
+             continue
+
+         net.append(activation())
+
+     return nn.Sequential(*net)
+
+
+ class MaskEstimator(Module):
+     @beartype
+     def __init__(
+             self,
+             dim,
+             dim_inputs: Tuple[int, ...],
+             depth,
+             mlp_expansion_factor=4
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_freqs = ModuleList([])
+         dim_hidden = dim * mlp_expansion_factor
+
+         for dim_in in dim_inputs:
+             net = []
+
+             mlp = nn.Sequential(
+                 MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+                 nn.GLU(dim=-1)
+             )
+
+             self.to_freqs.append(mlp)
+
+     def forward(self, x):
+         x = x.unbind(dim=-2)
+
+         outs = []
+
+         for band_features, mlp in zip(x, self.to_freqs):
+             freq_out = mlp(band_features)
+             outs.append(freq_out)
+
+         return torch.cat(outs, dim=-1)
+
+
+ # main class
+
+ class MelBandRoformer(Module):
+
+     @beartype
+     def __init__(
+             self,
+             dim,
+             *,
+             depth,
+             stereo=False,
+             num_stems=1,
+             time_transformer_depth=2,
+             freq_transformer_depth=2,
+             num_bands=60,
+             dim_head=64,
+             heads=8,
+             attn_dropout=0.1,
+             ff_dropout=0.1,
+             flash_attn=True,
+             dim_freqs_in=1025,
+             sample_rate=44100,  # needed for mel filter bank from librosa
+             stft_n_fft=2048,
+             stft_hop_length=512,
+             # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+             stft_win_length=2048,
+             stft_normalized=False,
+             stft_window_fn: Optional[Callable] = None,
+             mask_estimator_depth=1,
+             multi_stft_resolution_loss_weight=1.,
+             multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+             multi_stft_hop_size=147,
+             multi_stft_normalized=False,
+             multi_stft_window_fn: Callable = torch.hann_window,
+             match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
+     ):
+         super().__init__()
+
+         self.stereo = stereo
+         self.audio_channels = 2 if stereo else 1
+         self.num_stems = num_stems
+
+         self.layers = ModuleList([])
+
+         transformer_kwargs = dict(
+             dim=dim,
+             heads=heads,
+             dim_head=dim_head,
+             attn_dropout=attn_dropout,
+             ff_dropout=ff_dropout,
+             flash_attn=flash_attn
+         )
+
+         time_rotary_embed = RotaryEmbedding(dim=dim_head)
+         freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
+                 Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+             ]))
+
+         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+         self.stft_kwargs = dict(
+             n_fft=stft_n_fft,
+             hop_length=stft_hop_length,
+             win_length=stft_win_length,
+             normalized=stft_normalized
+         )
+
+         freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
+
+         # create mel filter bank
+         # with librosa.filters.mel as in section 2 of paper
+
+         mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
+
+         mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+         # for some reason, it doesn't include the first freq? just force a value for now
+
+         mel_filter_bank[0][0] = 1.
+
+         # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+         # so let's force a positive value
+
+         mel_filter_bank[-1, -1] = 1.
+
+         # binary as in paper (then estimated masks are averaged for overlapping regions)
+
+         freqs_per_band = mel_filter_bank > 0
+         assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+
+         repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
+         freq_indices = repeated_freq_indices[freqs_per_band]
+
+         if stereo:
+             freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+             freq_indices = freq_indices * 2 + torch.arange(2)
+             freq_indices = rearrange(freq_indices, 'f s -> (f s)')
+
+         self.register_buffer('freq_indices', freq_indices, persistent=False)
+         self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+
+         num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
+         num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+
+         self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
+         self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+
+         # band split and mask estimator
+
+         freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
+
+         self.band_split = BandSplit(
+             dim=dim,
+             dim_inputs=freqs_per_bands_with_complex
+         )
+
+         self.mask_estimators = nn.ModuleList([])
+
+         for _ in range(num_stems):
+             mask_estimator = MaskEstimator(
+                 dim=dim,
+                 dim_inputs=freqs_per_bands_with_complex,
+                 depth=mask_estimator_depth
+             )
+
+             self.mask_estimators.append(mask_estimator)
+
+         # for the multi-resolution stft loss
+
+         self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+         self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+         self.multi_stft_n_fft = stft_n_fft
+         self.multi_stft_window_fn = multi_stft_window_fn
+
+         self.multi_stft_kwargs = dict(
+             hop_length=multi_stft_hop_size,
+             normalized=multi_stft_normalized
+         )
+
+         self.match_input_audio_length = match_input_audio_length
+
+     def forward(
+             self,
+             raw_audio,
+             target=None,
+             return_loss_breakdown=False
+     ):
+         """
+         einops
+
+         b - batch
+         f - freq
+         t - time
+         s - audio channel (1 for mono, 2 for stereo)
+         n - number of 'stems'
+         c - complex (2)
+         d - feature dimension
+         """
+
+         device = raw_audio.device
+
+         if raw_audio.ndim == 2:
+             raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+         batch, channels, raw_audio_length = raw_audio.shape
+
+         istft_length = raw_audio_length if self.match_input_audio_length else None
+
+         assert (not self.stereo and channels == 1) or (
+                 self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+
+         # to stft
+
+         raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+         stft_window = self.stft_window_fn(device=device)
+
+         stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+         stft_repr = torch.view_as_real(stft_repr)
+
+         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+         stft_repr = rearrange(stft_repr,
+                               'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+
+         # index out all frequencies for all frequency ranges across bands ascending in one go
+
+         batch_arange = torch.arange(batch, device=device)[..., None]
+
+         # account for stereo
+
+         x = stft_repr[batch_arange, self.freq_indices]
+
+         # fold the complex (real and imag) into the frequencies dimension
+
+         x = rearrange(x, 'b f t c -> b t (f c)')
+
+         x = self.band_split(x)
+
+         # axial / hierarchical attention
+
+         for time_transformer, freq_transformer in self.layers:
+             x = rearrange(x, 'b t f d -> b f t d')
+             x, ps = pack([x], '* t d')
+
+             x = time_transformer(x)
+
+             x, = unpack(x, ps, '* t d')
+             x = rearrange(x, 'b f t d -> b t f d')
+             x, ps = pack([x], '* f d')
+
+             x = freq_transformer(x)
+
+             x, = unpack(x, ps, '* f d')
+
+         num_stems = len(self.mask_estimators)
+
+         masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+         masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+
+         # modulate frequency representation
+
+         stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+         # complex number multiplication
+
+         stft_repr = torch.view_as_complex(stft_repr)
+         masks = torch.view_as_complex(masks)
+
+         masks = masks.type(stft_repr.dtype)
+
+         # need to average the estimated mask for the overlapped frequencies
+
+         scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+
+         stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
+         masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+
+         denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
+
+         masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+         # modulate stft repr with estimated mask
+
+         stft_repr = stft_repr * masks_averaged
+
+         # istft
+
+         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+         recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
+                                   length=istft_length)
+
+         recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+
+         if num_stems == 1:
+             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+         # if a target is passed in, calculate loss for learning
+
+         if not exists(target):
+             return recon_audio
+
+         if self.num_stems > 1:
+             assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+         if target.ndim == 2:
+             target = rearrange(target, '... t -> ... 1 t')
+
+         target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+
+         loss = F.l1_loss(recon_audio, target)
+
+         multi_stft_resolution_loss = 0.
+
+         for window_size in self.multi_stft_resolutions_window_sizes:
+             res_stft_kwargs = dict(
+                 n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
+                 win_length=window_size,
+                 return_complex=True,
+                 window=self.multi_stft_window_fn(window_size, device=device),
+                 **self.multi_stft_kwargs,
+             )
+
+             recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+             target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+             multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+         weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+         total_loss = loss + weighted_multi_resolution_loss
+
+         if not return_loss_breakdown:
+             return total_loss
+
+         return total_loss, (loss, multi_stft_resolution_loss)
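
The module above is self-contained except for models/mel_band_roformer/attend.py, which provides the Attend class imported at the top but is not among the files added by this commit. Assuming that file is supplied, a minimal forward-pass sketch (shapes follow the einops legend in forward; hyperparameters loosely mirror the model section of config.yaml; illustrative, not part of the commit):

    import torch
    from models.mel_band_roformer import MelBandRoformer

    model = MelBandRoformer(dim=384, depth=6, stereo=True, num_bands=60)
    model.eval()

    audio = torch.randn(1, 2, 44100)  # (batch, channels, samples): one second of stereo

    with torch.no_grad():
        instrumental = model(audio)   # (batch, channels, samples) when num_stems == 1

    # passing a target returns the L1 + multi-resolution STFT training loss instead
    target = torch.randn(1, 2, 44100)
    loss = model(audio, target=target)
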
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.0.0
+ torchaudio
+ librosa
+ gradio
+ huggingface_hub
+ pyyaml
+ numpy
+ scipy
+ einops
+ beartype
+ rotary-embedding-torch
utils.py ADDED
@@ -0,0 +1,115 @@
+ import time
+ import numpy as np
+ import torch
+ import sys
+ import torch.nn as nn
+
+
+ def get_model_from_config(model_type, config):
+     # note: config is expected to provide attribute access
+     # (e.g. an ml_collections-style ConfigDict), not the plain dict yaml returns
+     if model_type == 'mel_band_roformer':
+         from models.mel_band_roformer import MelBandRoformer
+         model = MelBandRoformer(
+             **dict(config.model)
+         )
+     else:
+         print('Unknown model: {}'.format(model_type))
+         model = None
+
+     return model
+
+
+ def get_windowing_array(window_size, fade_size, device):
+     fadein = torch.linspace(0, 1, fade_size)
+     fadeout = torch.linspace(1, 0, fade_size)
+     window = torch.ones(window_size)
+     window[-fade_size:] *= fadeout
+     window[:fade_size] *= fadein
+     return window.to(device)
+
+ def demix_track(config, model, mix, device, first_chunk_time=None):
+     # note: in config.yaml, chunk_size is defined at the top level,
+     # not under the inference section
+     C = config.inference.chunk_size
+     N = config.inference.num_overlap
+     step = C // N
+     fade_size = C // 10
+     border = C - step
+
+     if mix.shape[1] > 2 * border and border > 0:
+         mix = nn.functional.pad(mix, (border, border), mode='reflect')
+
+     windowing_array = get_windowing_array(C, fade_size, device)
+
+     with torch.cuda.amp.autocast():
+         with torch.no_grad():
+             if config.training.target_instrument is not None:
+                 req_shape = (1, ) + tuple(mix.shape)
+             else:
+                 req_shape = (len(config.training.instruments),) + tuple(mix.shape)
+
+             mix = mix.to(device)
+             result = torch.zeros(req_shape, dtype=torch.float32).to(device)
+             counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
+
+             i = 0
+             total_length = mix.shape[1]
+             num_chunks = (total_length + step - 1) // step
+
+             if first_chunk_time is None:
+                 start_time = time.time()
+                 first_chunk = True
+             else:
+                 start_time = None
+                 first_chunk = False
+
+             while i < total_length:
+                 part = mix[:, i:i + C]
+                 length = part.shape[-1]
+                 if length < C:
+                     if length > C // 2 + 1:
+                         part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
+                     else:
+                         part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+
+                 if first_chunk and i == 0:
+                     chunk_start_time = time.time()
+
+                 x = model(part.unsqueeze(0))[0]
+
+                 window = windowing_array.clone()
+                 if i == 0:
+                     window[:fade_size] = 1
+                 elif i + C >= total_length:
+                     window[-fade_size:] = 1
+
+                 result[..., i:i+length] += x[..., :length] * window[..., :length]
+                 counter[..., i:i+length] += window[..., :length]
+                 i += step
+
+                 if first_chunk and i == step:
+                     chunk_time = time.time() - chunk_start_time
+                     first_chunk_time = chunk_time
+                     estimated_total_time = chunk_time * num_chunks
+                     print(f"Estimated total processing time for this track: {estimated_total_time:.2f} seconds")
+                     first_chunk = False
+
+                 if first_chunk_time is not None and i > step:
+                     chunks_processed = i // step
+                     time_remaining = first_chunk_time * (num_chunks - chunks_processed)
+                     sys.stdout.write(f"\rEstimated time remaining: {time_remaining:.2f} seconds")
+                     sys.stdout.flush()
+
+             print()
+             estimated_sources = result / counter
+             estimated_sources = estimated_sources.cpu().numpy()
+             np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+     if mix.shape[1] > 2 * border and border > 0:
+         estimated_sources = estimated_sources[..., border:-border]
+
+     if config.training.target_instrument is None:
+         return {k: v for k, v in zip(config.training.instruments, estimated_sources)}, first_chunk_time
+     else:
+         return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}, first_chunk_time
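
app.py runs the model over the entire waveform in a single pass; demix_track instead performs chunked overlap-add inference using the fade windows built above, which is the safer route for long tracks. A sketch of how it could be wired in (assumes a loaded model and an attribute-style config as in the earlier sketch; illustrative, not part of this commit):

    import torch
    import torchaudio
    from utils import demix_track

    def separate_with_chunks(model, config, audio_file, device='cpu'):
        waveform, sr = torchaudio.load(audio_file)  # (channels, samples)
        model = model.to(device).eval()
        stems, _ = demix_track(config, model, waveform, device)
        # target_instrument is 'Instrumental', so that is the only key returned
        instrumental = torch.from_numpy(stems['Instrumental'])
        torchaudio.save('output_instrumental.wav', instrumental, sr)
        return 'output_instrumental.wav'
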