pcunwa committed on
Commit
bdb0f8a
·
verified ·
1 Parent(s): 17bb4d6

Upload bs_roformer.py

Browse files
Files changed (1) hide show
  1. bs_roformer.py +683 -0
bs_roformer.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.attend import Attend
9
try:
    # SageAttention is an optional accelerated kernel; fall back silently
    # when the module is not installed. Catch only ImportError — a bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit and hide
    # genuine errors raised while importing the module.
    from models.bs_roformer.attend_sage import Attend as AttendSage
except ImportError:
    pass
13
+ from torch.utils.checkpoint import checkpoint
14
+
15
+ from beartype.typing import Tuple, Optional, List, Callable
16
+ from beartype import beartype
17
+
18
+ from rotary_embedding_torch import RotaryEmbedding
19
+
20
+ from einops import rearrange, pack, unpack
21
+ from einops.layers.torch import Rearrange
22
+ # helper functions
23
+
24
def exists(val):
    """Return True when *val* is anything other than None (falsy values count as existing)."""
    return not (val is None)
26
+
27
+
28
def default(v, d):
    """Return *v* unless it is None, in which case fall back to *d*."""
    if v is None:
        return d
    return v
30
+
31
+
32
def pack_one(t, pattern):
    """einops-pack a single tensor; returns (packed tensor, shape info for unpacking)."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
34
+
35
+
36
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: unpack and return the single original tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
38
+
39
+
40
+ # norm
41
+
42
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
44
+
45
+
46
class RMSNorm(Module):
    """Root-mean-square normalization over the last dim with a learned per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) restores the magnitude removed by unit-normalizing.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.gamma * self.scale
54
+
55
+
56
+ # attention
57
+
58
class FeedForward(Module):
    """Pre-norm two-layer MLP (GELU activation) used after each attention layer."""

    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.
    ):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
78
+
79
class Attention(Module):
    """Multi-head attention with optional rotary positions and per-head sigmoid output gates."""

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner = heads * dim_head

        self.rotary_embed = rotary_embed

        # Optionally swap in the SageAttention kernel; AttendSage is only
        # resolved when sage_attention is requested.
        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)

        # One gate per head, computed from the normed input tokens.
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if self.rotary_embed is not None:
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        # Gate each head's output with a sigmoid computed from the tokens.
        gate = rearrange(self.to_gates(x), 'b n h -> b h n 1')
        out = out * gate.sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
128
+
129
+
130
class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=False,
        dropout=0.,
        sage_attention=False,
    ):
        super().__init__()
        inner = dim_head * heads
        self.norm = RMSNorm(dim)

        # Project to q/k/v laid out channel-first: (qkv, b, h, d, n).
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        )

        # Learned per-head temperature applied to the normalized queries.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(inner, dim, bias=False)
        )

    def forward(
        self,
        x
    ):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        # Unit-normalize queries and keys, then rescale queries by exp(temperature).
        q = F.normalize(q, p=2, dim=-1) * self.temperature.exp()
        k = F.normalize(k, p=2, dim=-1)

        out = self.attend(q, k, v)

        return self.to_out(out)
190
+
191
class Transformer(Module):
    """Stack of residual (attention, feed-forward) layers with an optional final RMSNorm."""

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.,
        ff_dropout=0.,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            # Each layer pairs an attention block (linear or softmax flavor)
            # with a feed-forward block; both are applied residually.
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )

            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn, ff]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return self.norm(x)
246
+
247
+
248
+ # bandsplit module
249
+
250
+
251
+
252
class BandSplit(Module):
    """Splits the last dim into per-band chunks and projects each band to a shared feature size."""

    @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        # One (RMSNorm -> Linear) embedder per band width.
        self.to_features = ModuleList([
            nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
            for dim_in in dim_inputs
        ])

    def forward(self, x):
        # Chunk the feature dim into bands, embed each band, then stack
        # the embeddings on a new band axis just before the feature dim.
        bands = x.split(self.dim_inputs, dim=-1)
        outs = [embed(band) for band, embed in zip(bands, self.to_features)]
        return torch.stack(outs, dim=-2)
283
+
284
def MLP(
    dim_in,
    dim_out,
    dim_hidden=None,
    depth=1,
    activation=nn.Tanh
):
    """Build a `depth`-layer MLP with `activation` between hidden layers (none after the last)."""
    # Hidden width defaults to the input width.
    if dim_hidden is None:
        dim_hidden = dim_in

    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)

    layers = []
    last = len(dims) - 2
    for ind, (layer_in, layer_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(layer_in, layer_out))
        if ind != last:
            layers.append(activation())

    return nn.Sequential(*layers)
307
+
308
class MaskEstimator(Module):
    """Maps band features to per-band mask values.

    Runs 4 fixed axial (time-then-band) transformer refinement passes over
    the input, then feeds each band's features through its own MLP + GLU
    head and concatenates all band outputs along the last dimension. Each
    head emits 2 * dim_in values which the GLU halves back to dim_in.
    """

    @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...],
        depth,
        mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            net = []  # NOTE(review): unused leftover; the head below is built directly

            # MLP widens to 2 * dim_in so the GLU gate halves it to dim_in.
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

        self.layers = ModuleList([])

        # Refinement-transformer settings are hard-coded here, independent
        # of whatever heads/dim_head the parent model was configured with.
        heads = 8
        dim_head = 64

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=0.,
            ff_dropout=0.,
            flash_attn=True,
            norm_output=False,
            sage_attention=False,
        )

        # One shared rotary embedding per axis, reused across all 4 blocks.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        # 4 refinement blocks, each a (time transformer, freq transformer) pair.
        for _ in range(4):
            tran_modules = []
            tran_modules.append(
                Transformer(depth=1, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=1, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.norm = RMSNorm(dim)

    def forward(self, x):
        # x is laid out (b, t, f, d) per the rearrange patterns below —
        # axial attention alternates between the time and band axes.
        for i, transformer_block in enumerate(self.layers):

            time_transformer, freq_transformer = transformer_block

            # Fold bands into the batch dim and attend over time.
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            # Fold time into the batch dim and attend across bands.
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        x = self.norm(x)

        # Split off the band axis; one tensor per band.
        x = x.unbind(dim=-2)

        outs = []

        # Run each band through its dedicated mask head.
        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        # Concatenate band outputs back into one feature vector per (b, t).
        return torch.cat(outs, dim=-1)
397
+
398
+
399
+ # main class
400
+
401
# Default band widths in STFT frequency bins: narrow bands at low
# frequencies, progressively wider toward the top. The entries sum to
# 1025, matching the (2048 // 2) + 1 bins produced by the default
# stft_n_fft=2048 (asserted against the actual STFT in BSRoformer.__init__).
DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)
411
+
412
class BSRoformer(Module):
    """Band-Split RoFormer for music source separation.

    Pipeline: STFT -> split frequency bins into bands -> axial (time / band)
    rotary transformers -> per-stem complex mask estimation -> inverse STFT.
    When ``target`` is given to ``forward``, returns an L1 waveform loss plus
    a weighted multi-resolution STFT loss instead of reconstructed audio.

    Fix over the original: the two MPS STFT/iSTFT fallbacks used bare
    ``except:``, which also swallows KeyboardInterrupt/SystemExit; they now
    catch ``Exception`` only, preserving the best-effort CPU fallback.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,  # NOTE(review): accepted but unused in this variant
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        # in the paper, they divide into ~60 bands, test with 1 for starters
        dim_head=64,
        heads=8,
        attn_dropout=0.,
        ff_dropout=0.,
        flash_attn=True,
        dim_freqs_in=1025,  # NOTE(review): unused; freq count is derived from a probe STFT below
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        # NOTE(review): stored but not referenced by forward() in this version.
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Rotary embeddings are shared across all depth levels, one per axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        # Probe STFT on dummy audio to find the actual number of frequency bins.
        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(
            freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'

        # Each bin contributes (real, imag) per audio channel.
        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

    def forward(
        self,
        raw_audio,
        target=None,
        return_loss_breakdown=False
    ):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        device = raw_audio.device

        # defining whether model is loaded on MPS (MacOS GPU accelerator)
        x_is_mps = True if device.type == "mps" else False

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        # RuntimeError: FFT operations are only supported on MacOS 14+
        # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
        try:
            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        except Exception:  # narrowed from bare `except:`; keeps the MPS CPU fallback
            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs,
                                   window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(
                device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        # flatten (freq, complex) into one feature vector per time step
        x = rearrange(stft_repr, 'b f t c -> b t (f c)')

        x = self.band_split(x)

        # axial / hierarchical attention

        for i, transformer_block in enumerate(self.layers):

            time_transformer, freq_transformer = transformer_block

            # fold bands into batch and attend over time
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            # fold time into batch and attend across bands
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        # one complex mask per stem
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        try:
            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
        except Exception:  # same MPS CPU fallback as the forward STFT above
            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)