HiDolen commited on
Commit
be11548
·
verified ·
1 Parent(s): 427df84

Upload 2 files

Browse files
configuration_bs_roformer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BS-RoFormer model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
# Default frequency-band partition (number of STFT bins per band), narrow at low
# frequencies and wide at high frequencies. The entries sum to 1025, which must
# equal `stft_n_fft // 2 + 1` for the default `stft_n_fft=2048` (checked by an
# assert in `BSRoformerForMaskedEstimation.__init__`).
DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)
14
+
15
+
16
class BSRoformerConfig(PretrainedConfig):
    """Configuration for the BS-RoFormer source-separation model.

    Settings fall into three groups: the transformer backbone (sizes, depths,
    attention layout), the STFT front-end that builds the input spectrogram,
    and training-time options (loss windows, initializer range, register
    tokens). `freqs_per_bands` must sum to `stft_n_fft // 2 + 1`.
    """

    model_type = "bs_roformer"

    def __init__(
        self,
        hidden_size=384,
        depth=6,
        num_input_channel=1,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        attention_dropout=0.0,
        num_attention_heads=8,
        num_key_value_heads=8,
        intermediate_size=384 * 4,
        #
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        mask_estimator_depth=2,
        multi_stft_loss_weight=1.0,
        multi_stft_loss_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_loss_hop_size=147,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        #
        initializer_range=0.02,
        register_token_num=4,
        **kwargs,
    ):
        # --- transformer backbone ---
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.depth = depth
        self.time_transformer_depth = time_transformer_depth
        self.freq_transformer_depth = freq_transformer_depth
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta

        # --- audio / band-split front-end ---
        self.num_input_channel = num_input_channel
        self.num_stems = num_stems
        self.freqs_per_bands = freqs_per_bands
        self.stft_n_fft = stft_n_fft
        self.stft_hop_length = stft_hop_length
        self.stft_win_length = stft_win_length

        # --- mask estimation and loss ---
        self.mask_estimator_depth = mask_estimator_depth
        self.multi_stft_loss_weight = multi_stft_loss_weight
        self.multi_stft_loss_window_sizes = multi_stft_loss_window_sizes
        self.multi_stft_loss_hop_size = multi_stft_loss_hop_size

        # --- initialization / extras ---
        self.initializer_range = initializer_range
        self.register_token_num = register_token_num

        super().__init__(**kwargs)
modeling_bs_roformer.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from transformers.activations import ACT2FN
9
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
10
+
11
+ from .configuration_bs_roformer import BSRoformerConfig
12
+
13
+
14
+
15
def rotate_half(x):
    """Rotate the last dimension for rotary embeddings.

    Splits the last dimension into two halves [a, b] and returns [-b, a].
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
19
+
20
+
21
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Computes `x * cos + rotate_half(x) * sin` for both inputs and returns the
    rotated (query, key) pair.
    """
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
25
+
26
+
27
class RotaryEmbedding(nn.Module):
    """Computes rotary position embedding tables (cos, sin) for a sequence.

    The inverse frequencies follow the standard RoPE schedule
    `theta ** (-2i / head_dim)`; `forward` returns per-position cos/sin of
    shape `(batch, seq_len, head_dim)` in the dtype of `x`.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, position_ids):
        """Return `(cos, sin)` tables for `position_ids`.

        `x` is used only for its device and dtype; `position_ids` has shape
        `(batch, seq_len)`.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Fix: `torch.device.type` is always a str, so the original
        # `isinstance(..., str)` guard was dead code. Autocast is not supported
        # on mps, hence the cpu fallback.
        device_type = x.device.type if x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32 for numerical precision
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies so cos/sin cover the full head_dim,
            # matching the half-split layout used by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
46
+
47
+
48
class BSRoformerMLP(nn.Module):
    """Gated feed-forward block (GeGLU-style): GELU(gate) * up, then down-projection.

    All three projections are bias-free linears between `hidden_size` and
    `intermediate_size`.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["gelu"]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
62
+
63
+
64
class BSRoformerAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary position embeddings.

    Dispatches to the Transformers SDPA attention backend and supports
    grouped-query attention when `num_key_value_heads < num_attention_heads`.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.is_causal = False
        self.config = config

        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        # Number of query heads sharing one key/value head (GQA); read by the
        # SDPA backend to repeat k/v as needed.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        q_dim = config.num_attention_heads * self.head_dim
        kv_dim = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask=None,
    ):
        batch_dims = hidden_states.size()[:-1]
        split_heads = (*batch_dims, -1, self.head_dim)

        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(hidden_states).view(split_heads).transpose(1, 2)
        k = self.k_proj(hidden_states).view(split_heads).transpose(1, 2)
        v = self.v_proj(hidden_states).view(split_heads).transpose(1, 2)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        sdpa = ALL_ATTENTION_FUNCTIONS["sdpa"]
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = sdpa(
            self,
            q,
            k,
            v,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*batch_dims, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
113
+
114
+
115
class BSRoformerLayer(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention -> residual,
    then RMSNorm -> gated MLP -> residual.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.self_attn = BSRoformerAttention(config)
        self.mlp = BSRoformerMLP(config)

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        # Attention sub-block with pre-normalization.
        attn_out, _ = self.self_attn(
            self.input_layernorm(hidden_states),
            position_embeddings,
            attention_mask,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with pre-normalization.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
145
+
146
+
147
class BSRoformerAxialTransformer(nn.Module):
    """Runs a stack of transformer layers along one axis of a
    (batch, time, freq, dim) grid.

    For a time transformer the freq axis is folded into the batch so attention
    runs over time; for a freq transformer the time axis is folded instead.
    """

    def __init__(
        self,
        config: BSRoformerConfig,
        transformer_depth: int,
        is_time_transformer: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])
        self.is_time_transformer = is_time_transformer

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        if self.is_time_transformer:
            # (b, t, f, d) -> (b, f, t, d) so attention runs along time.
            hidden_states = hidden_states.transpose(1, 2)

        batch, outer, inner, dim = hidden_states.shape
        # Fold the outer axis into the batch: (b, n, m, d) -> (b*n, m, d).
        hidden_states = hidden_states.reshape(batch * outer, inner, dim)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings,
                attention_mask,
            )

        # Restore the folded axis.
        hidden_states = hidden_states.reshape(batch, outer, inner, dim)

        if self.is_time_transformer:
            hidden_states = hidden_states.transpose(1, 2)

        return hidden_states
183
+
184
+
185
+ class BandSplit(nn.Module):
186
+ def __init__(self, config: BSRoformerConfig):
187
+ super().__init__()
188
+ self.dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
189
+ self.to_features = nn.ModuleList(
190
+ [
191
+ nn.Sequential(nn.RMSNorm(dim_in, eps=config.rms_norm_eps), nn.Linear(dim_in, config.hidden_size))
192
+ for dim_in in self.dim_inputs
193
+ ]
194
+ )
195
+
196
+ def forward(self, x):
197
+ x_split = x.split(self.dim_inputs, dim=-1)
198
+ outs = [to_feature(split_input) for split_input, to_feature in zip(x_split, self.to_features)]
199
+ return torch.stack(outs, dim=-2)
200
+
201
+
202
class MaskEstimator(nn.Module):
    """Maps per-band hidden states back to per-band STFT mask coefficients.

    One linear head per band projects `hidden_size` to that band's flattened
    width (`2 * f * num_input_channel`); outputs are concatenated along the
    feature axis.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        band_dims = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
        self.to_freq_mlps = nn.ModuleList([nn.Linear(config.hidden_size, band_dim) for band_dim in band_dims])

    def forward(self, x):
        per_band = []
        for band_features, mlp in zip(x.unbind(dim=-2), self.to_freq_mlps):
            per_band.append(mlp(band_features))
        return torch.cat(per_band, dim=-1)
212
+
213
+
214
class BSRoformerPreTrainedModel(PreTrainedModel):
    """Base class wiring BS-RoFormer models into the Hugging Face
    `PreTrainedModel` machinery (weight init, checkpoint (de)serialization,
    device placement).
    """

    # Configuration class consumed by `from_pretrained` / `save_pretrained`.
    config_class = BSRoformerConfig
    # Checkpoint prefix under which the backbone weights live.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep a whole transformer layer on one device when sharding the model.
    _no_split_modules = ["BSRoformerLayer"]
219
+
220
+
221
class BSRoformerModel(BSRoformerPreTrainedModel):
    """BS-RoFormer backbone.

    Splits the flattened STFT features into frequency bands, then alternates
    axial transformers along the time and frequency axes `config.depth` times.
    Learnable register tokens are appended in the padded corner of the
    (time, freq) grid and stripped off before the final norm.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.band_split = BandSplit(config)
        # One (time transformer, frequency transformer) pair per depth level.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),
                    BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),
                ]
            )
            for _ in range(config.depth)
        )
        self.rotary_emb = RotaryEmbedding(config)
        self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Register tokens fill only the rn x rn corner of the padded grid;
        # the remaining padded strips stay zero (see `forward`).
        rn = config.register_token_num
        self.register_tokens = nn.Parameter(torch.normal(0, 0.02, size=(rn, rn, config.hidden_size)))

        self.post_init()

    def forward(
        self,
        x,
        position_ids=None,
    ):
        """Run the axial transformer stack.

        `x` is the band-flattened STFT of shape (batch, time, sum(band_dims));
        returns normalized hidden states of shape (batch, time, num_bands,
        hidden_size).
        """
        hidden_states = self.band_split(x)  # (b, t, num_bands, hidden)

        b, t, n, h = hidden_states.shape

        if position_ids is None:
            position_ids = torch.arange(t, device=hidden_states.device).unsqueeze(0)
        # Separate rotary tables for the time axis (length t) and the
        # frequency/band axis (length n).
        pos_embeds = self.rotary_emb(hidden_states, position_ids)
        pos_embeds_for_freq = self.rotary_emb(
            hidden_states,
            torch.arange(n, device=hidden_states.device).unsqueeze(0),
        )

        # Pad both the time and band axes by `rn`, then write the register
        # tokens into the (t:, n:) corner. NOTE: the off-corner padded strips
        # ([:, t:, :n] and [:, :t, n:]) remain zero vectors by design of this
        # padding scheme — confirm this is intended.
        rn = self.config.register_token_num
        hidden_states = F.pad(hidden_states, (0, 0, 0, rn, 0, rn))
        hidden_states[:, t:, n:, :] = self.register_tokens

        def pad_rope(cos, sin):
            # cos=1 / sin=0 on padded positions -> identity rotation for the
            # register-token slots.
            cos_padded = F.pad(cos, (0, 0, 0, rn), value=1.0)
            sin_padded = F.pad(sin, (0, 0, 0, rn), value=0.0)
            return cos_padded, sin_padded

        pos_embeds = pad_rope(*pos_embeds)
        pos_embeds_for_freq = pad_rope(*pos_embeds_for_freq)

        # Alternate attention along time then along frequency bands.
        for time_transformer, freq_transformer in self.layers:
            hidden_states = time_transformer(
                hidden_states,
                position_embeddings=pos_embeds,
                attention_mask=None,
            )
            hidden_states = freq_transformer(
                hidden_states,
                position_embeddings=pos_embeds_for_freq,
                attention_mask=None,
            )

        # Drop the register-token padding before the final norm.
        hidden_states = hidden_states[:, :t, :n, :]

        return self.final_norm(hidden_states)
287
+
288
+
289
class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):
    """BS-RoFormer with complex mask estimation heads for source separation.

    Computes the STFT of the input mixture, runs the backbone, predicts one
    complex spectrogram mask per stem, applies the masks, and reconstructs
    waveforms with the inverse STFT. With a `target` it returns an L1 loss;
    otherwise the reconstructed audio.

    NOTE(review): the config carries `multi_stft_loss_*` parameters, but this
    `forward` only computes a plain L1 waveform loss — the multi-resolution
    STFT loss appears unimplemented here; confirm.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.model = BSRoformerModel(config)
        # One mask estimator per output stem.
        self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])

        self.stft_kwargs = dict(
            n_fft=config.stft_n_fft,
            hop_length=config.stft_hop_length,
            win_length=config.stft_win_length,
            normalized=False,
        )
        # Analysis/synthesis window; not saved in checkpoints (persistent=False).
        self.register_buffer("stft_window", torch.hann_window(config.stft_win_length), persistent=False)

        # The band partition must cover every STFT bin exactly once.
        freqs = config.stft_n_fft // 2 + 1
        assert sum(config.freqs_per_bands) == freqs, f"Sum of freqs_per_bands must be {freqs}"
        self.wave_channels = config.num_input_channel

    def forward(
        self,
        raw_audio: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ):
        """Separate `raw_audio` of shape (batch, channels, time).

        Returns the L1 loss against `target` (shape (batch, num_stems,
        channels, time)) when given, else the reconstructed audio of shape
        (batch, num_stems, channels, time).
        """
        device = raw_audio.device

        # STFT in full float32 precision regardless of surrounding autocast.
        with torch.autocast(device_type=device.type, enabled=False):
            b, c, t = raw_audio.shape
            raw_audio_packed = rearrange(raw_audio, "b c t -> (b c) t")
            stft_repr = torch.stft(
                raw_audio_packed,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=True,
            )
            # Complex -> trailing real/imag axis of size 2 (named T below).
            stft_repr = torch.view_as_real(stft_repr)
            stft_repr = rearrange(stft_repr, "(b c) f t T -> b c f t T", c=c)
            # Flatten (freq, channel, real/imag) per frame for the band split.
            stft_repr_merged = rearrange(stft_repr, "b c f t T -> b t (f c T)")

        hidden_states = self.model(stft_repr_merged)

        # One complex mask per stem, stacked on a new stem axis.
        mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c T) -> b n c f t T", T=2, c=c)
        mask = mask.to(dtype=torch.float32)

        # Complex multiply and inverse STFT in float32.
        with torch.autocast(device_type=device.type, enabled=False):
            stft_repr_expanded = rearrange(stft_repr, "b c f t T -> b 1 c f t T")
            stft_repr_complex = torch.view_as_complex(stft_repr_expanded)
            mask_complex = torch.view_as_complex(mask)
            masked_stft = stft_repr_complex * mask_complex

            masked_stft = rearrange(masked_stft, "b n c f t -> (b n c) f t")
            recon_audio = torch.istft(
                masked_stft,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
            recon_audio = rearrange(recon_audio, "(b n c) t -> b n c t", c=self.wave_channels, n=self.config.num_stems)

        if target is None:
            return recon_audio

        # Trim target to the reconstructed length before the loss.
        target = target[..., : recon_audio.shape[-1]]
        loss = F.l1_loss(recon_audio, target)
        return loss

    def separate(
        self,
        mixed_wave: torch.Tensor,
        chunk_size: int = 44100 * 8,
        overlap_size: int = 44100 * 4,
        batch_size: int = 16,
        gap_size: int = 44100 * 1,
        verbose: bool = True,
    ):
        """
        Separates a full audio waveform into its constituent stems using a sliding window approach.

        Args:
            mixed_wave (`torch.Tensor` of shape `(channels, time)`):
                The raw audio waveform of the mixture.
            chunk_size (`int`, *optional*, defaults to `352800` (8 seconds at 44.1kHz)):
                The size of each audio chunk for processing.
            overlap_size (`int`, *optional*, defaults to `176400` (4 seconds at 44.1kHz)):
                The size of the overlap between consecutive chunks.
            batch_size (`int`, *optional*, defaults to `16`):
                The number of chunks to process in a single batch.
            gap_size (`int`, *optional*, defaults to `44100` (1 second at 44.1kHz)):
                The size of the gap for the fade-in/fade-out window.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether to print progress information during processing.

        Returns:
            torch.Tensor (`torch.Tensor` of shape `(num_stems, channels, time)`):
                The separated audio waveforms.
        """
        if mixed_wave.dim() != 2:
            raise ValueError("Input `mixed_wave` must be a 2D tensor of shape (channels, time)")

        device = mixed_wave.device

        # Fade-in/fade-out window: zero in the outer `gap_size` margins, linear
        # fades of `chunk_size // 10` samples just inside them, ones in between.
        # NOTE(review): every chunk is weighted zero in its outer margins, so
        # the first/last `gap_size` samples of the whole track get ~zero total
        # weight and come out silent after the clamped division below — confirm
        # whether the track edges need special-casing.
        fade_size = chunk_size // 10
        window = torch.ones(chunk_size - 2 * gap_size, device=device)
        window[:fade_size] = torch.linspace(0, 1, fade_size, device=device)
        window[-fade_size:] = torch.linspace(1, 0, fade_size, device=device)
        window = F.pad(window, (gap_size, gap_size), value=0.0)

        with torch.inference_mode():
            wave_length = mixed_wave.shape[-1]

            if wave_length <= chunk_size:
                num_chunks = 1
            else:
                num_chunks = math.ceil((wave_length - chunk_size) / overlap_size) + 1

            # Zero-pad so the last chunk is full-length.
            required_length = (num_chunks - 1) * overlap_size + chunk_size
            padded_wave = F.pad(
                mixed_wave,
                (0, required_length - wave_length),
                mode="constant",
            )

            unfolded_chunks = padded_wave.unfold(
                dimension=-1,
                size=chunk_size,
                step=overlap_size,
            )  # (C, num_chunks, chunk_size)
            batch = unfolded_chunks.permute(1, 0, 2)  # (num_chunks, C, chunk_size)

            if verbose:
                print(f"Input wave shape: {mixed_wave.shape}")
                print(f"Padded wave shape: {padded_wave.shape}")
                print(f"Number of chunks: {num_chunks}")
            output_chunks = []
            for i in range(0, num_chunks, batch_size):
                chunk_batch = batch[i : i + batch_size]
                output_chunk = self(chunk_batch)  # Call forward method
                output_chunks.append(output_chunk)
                if verbose:
                    print(f"Processed chunks {i} to {i + chunk_batch.shape[0]}")
            batch_output = torch.cat(output_chunks, dim=0)  # (num_chunks, num_stems, C, chunk_size)

            _, num_stems, C, _ = batch_output.shape
            batch_output = batch_output.view(num_chunks, -1, chunk_size).permute(1, 0, 2)  # (num_stems * C, num_chunks, chunk_size)
            batch_output = batch_output * window
            # Overlap-add of the weighted chunks via F.fold.
            output_result_buffer = F.fold(
                batch_output.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (num_stems * C, 1, 1, required_length)

            # Overlap-add of the window itself gives the per-sample weight sum.
            window_for_fold = window.expand(1, 1, -1).repeat(1, num_chunks, 1)
            weighted_sum_counter = F.fold(
                window_for_fold.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (1, 1, 1, required_length)

            output_result_buffer = output_result_buffer.view(num_stems, C, -1)  # (num_stems, C, required_length)
            weighted_sum_counter = weighted_sum_counter.view(1, 1, -1)
            # Guard against division by zero where no chunk contributed weight.
            weighted_sum_counter.clamp_min_(1e-8)

            # Normalize by the summed window weights and trim the padding.
            final_output = (output_result_buffer / weighted_sum_counter)[:, :, :wave_length]

        return final_output