ArianatorQualquer committed
Commit c0be788 · verified · 1 Parent(s): a8a84ee

Upload utils.py

Files changed (1)
  1. utils.py +224 -0
utils.py ADDED
@@ -0,0 +1,224 @@
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import time
import numpy as np
import torch
import torch.nn as nn
import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm import tqdm


def get_model_from_config(model_type, config_path):
    # htdemucs ships a plain OmegaConf YAML config; every other model type
    # wraps the parsed YAML in an ml_collections ConfigDict.
    if model_type == 'htdemucs':
        config = OmegaConf.load(config_path)
    else:
        with open(config_path) as f:
            config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    if model_type == 'mdx23c':
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
        model = TFC_TDF_net(config)
    elif model_type == 'htdemucs':
        from models.demucs4ht import get_model
        model = get_model(config)
    elif model_type == 'segm_models':
        from models.segm_models import Segm_Models_Net
        model = Segm_Models_Net(config)
    elif model_type == 'torchseg':
        from models.torchseg_models import Torchseg_Net
        model = Torchseg_Net(config)
    elif model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'swin_upernet':
        from models.upernet_swin_transformers import Swin_UperNet_Model
        model = Swin_UperNet_Model(config)
    elif model_type == 'bandit':
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
        model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    elif model_type == 'bandit_v2':
        from models.bandit_v2.bandit import Bandit
        model = Bandit(**config.kwargs)
    elif model_type == 'scnet_unofficial':
        from models.scnet_unofficial import SCNet
        model = SCNet(**config.model)
    elif model_type == 'scnet':
        from models.scnet import SCNet
        model = SCNet(**config.model)
    else:
        print('Unknown model: {}'.format(model_type))
        model = None

    return model, config
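# A minimal loading sketch (hypothetical paths; assumes a checkpoint that
# matches the chosen config, as in ZFTurbo's Music-Source-Separation-Training
# repository layout):
#
#     model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#     state_dict = torch.load('ckpts/model_mdx23c.ckpt', map_location='cpu')
#     model.load_state_dict(state_dict)
#     model = model.to('cuda').eval()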


def _getWindowingArray(window_size, fade_size):
    # Trapezoidal window: linear fade-in, flat top, linear fade-out.
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window
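# Shape check with illustrative sizes (not the values used at inference time):
#
#     _getWindowingArray(window_size=10, fade_size=3)
#     # -> tensor([0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0])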


def demix_track(config, model, mix, device, pbar=False):
    C = config.audio.chunk_size
    N = config.inference.num_overlap
    fade_size = C // 10
    step = int(C // N)
    border = C - step
    batch_size = config.inference.batch_size

    length_init = mix.shape[-1]

    # Pad at the beginning and end so the sliding window covers the borders better
    if length_init > 2 * border and (border > 0):
        mix = nn.functional.pad(mix, (border, border), mode='reflect')

    # windowingArray crossfades at segment boundaries to mitigate clicking artifacts
    windowingArray = _getWindowingArray(C, fade_size)

    with torch.cuda.amp.autocast(enabled=config.training.use_amp):
        with torch.inference_mode():
            if config.training.target_instrument is not None:
                req_shape = (1,) + tuple(mix.shape)
            else:
                req_shape = (len(config.training.instruments),) + tuple(mix.shape)

            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    if length > C // 2 + 1:
                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                    else:
                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    # Clone so the edge-chunk tweaks below don't mutate the shared window
                    window = windowingArray.clone()
                    if i - step == 0:  # First audio chunk, no fadein
                        window[:fade_size] = 1
                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
                        window[-fade_size:] = 1

                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        result[..., start:start + l] += x[j][..., :l].cpu() * window[..., :l]
                        counter[..., start:start + l] += window[..., :l]

                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            if length_init > 2 * border and (border > 0):
                # Remove pad
                estimated_sources = estimated_sources[..., border:-border]

    if config.training.target_instrument is None:
        return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
    else:
        return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}
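# A demixing sketch (hypothetical file and instrument names; assumes a stereo
# float32 mix at the sample rate the model was trained on):
#
#     import librosa
#     audio, _ = librosa.load('song.wav', sr=44100, mono=False)  # (2, T)
#     mix = torch.tensor(audio, dtype=torch.float32)
#     sources = demix_track(config, model, mix, device='cuda')
#     vocals = sources['vocals']  # numpy array, (2, T), if 'vocals' is configured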


def demix_track_demucs(config, model, mix, device, pbar=False):
    S = len(config.training.instruments)
    C = config.training.samplerate * config.training.segment
    N = config.inference.num_overlap
    batch_size = config.inference.batch_size
    step = C // N

    with torch.cuda.amp.autocast(enabled=config.training.use_amp):
        with torch.inference_mode():
            req_shape = (S,) + tuple(mix.shape)
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)
                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        # Overlap-add with a uniform weight of 1 (no crossfade window here)
                        result[..., start:start + l] += x[j][..., :l].cpu()
                        counter[..., start:start + l] += 1.
                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

    if S > 1:
        return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
    else:
        return estimated_sources
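# Note the asymmetric return type: a dict keyed by instrument when S > 1,
# otherwise the raw array of shape (1, channels, samples). A hypothetical call:
#
#     sources = demix_track_demucs(config, model, mix, device='cuda')
#     drums = sources['drums'] if isinstance(sources, dict) else sources[0]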


def sdr(references, estimates):
    # Compute per-source SDR (in dB) for one song; inputs are shaped
    # (num_sources, num_samples, num_channels).
    delta = 1e-7  # avoid numerical errors
    num = np.sum(np.square(references), axis=(1, 2))
    den = np.sum(np.square(references - estimates), axis=(1, 2))
    num += delta
    den += delta
    return 10 * np.log10(num / den)
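# Worked example (illustrative numbers): if the residual carries 1/100 of the
# reference energy, SDR = 10 * log10(100) = 20 dB.
#
#     refs = np.random.randn(1, 44100, 2).astype(np.float32)
#     ests = refs + 0.1 * np.random.randn(1, 44100, 2).astype(np.float32)
#     sdr(refs, ests)  # roughly array([20.])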