inoryQwQ committed
Commit ab5bd26 · verified · 1 Parent(s): 3cd0042

Upload folder using huggingface_hub

Files changed (8)
  1. .gitignore +2 -0
  2. MelBandRoformer.py +293 -0
  3. config.json +0 -0
  4. gradio_app.py +87 -0
  5. main.py +119 -0
  6. mel_band_roformer.axmodel +3 -0
  7. requirements.txt +9 -0
  8. screenshot.png +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+__pycache__
+*.wav
MelBandRoformer.py ADDED
@@ -0,0 +1,293 @@
+import axengine as axe
+import numpy as np
+import soundfile as sf
+import librosa
+import torch
+import tqdm
+from librosa import filters
+from einops import rearrange, reduce, repeat
+from typing import Union
+
+
+class MelBandRoformer:
+    def __init__(
+        self,
+        model_path,
+        *,
+        stft_n_fft=2048,
+        stft_win_length=2048,
+        stft_hop_length=441,
+        stft_normalized=False,
+        sample_rate=44100,
+        num_bands=60,
+        stereo=True
+    ):
+        self.stft_kwargs = dict(
+            n_fft=stft_n_fft,
+            hop_length=stft_hop_length,
+            win_length=stft_win_length,
+            normalized=stft_normalized,
+        )
+        self.sample_rate = sample_rate
+        self.num_bands = num_bands
+        self.stereo = stereo
+        self.num_channels = 2 if stereo else 1
+
+        self.freq_indices, _, _, self.num_bands_per_freq = self.calc_freq_indices()
+
+        self.model = axe.InferenceSession(
+            model_path,
+            providers=["AxEngineExecutionProvider", "AXCLRTExecutionProvider"],
+        )
+
+    def calc_freq_indices(self):
+        # run a dummy STFT just to read off the number of frequency bins
+        freqs = torch.stft(
+            torch.randn(1, 4096),
+            **self.stft_kwargs,
+            window=torch.ones(self.stft_kwargs["n_fft"]),
+            return_complex=True
+        ).shape[1]
+
+        # create mel filter bank
+        # with librosa.filters.mel as in section 2 of paper
+        mel_filter_bank_numpy = filters.mel(
+            sr=self.sample_rate, n_fft=self.stft_kwargs["n_fft"], n_mels=self.num_bands
+        )
+
+        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+        # for some reason, it doesn't include the first freq? just force a value for now
+        mel_filter_bank[0][0] = 1.0
+
+        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+        # so let's force a positive value
+        mel_filter_bank[-1, -1] = 1.0
+
+        # binary as in paper (then estimated masks are averaged for overlapping regions)
+        freqs_per_band = mel_filter_bank > 0
+        assert freqs_per_band.any(
+            dim=0
+        ).all(), "every frequency needs to be covered by at least one band"
+
+        repeated_freq_indices = repeat(
+            torch.arange(freqs), "f -> b f", b=self.num_bands
+        )
+        freq_indices = repeated_freq_indices[freqs_per_band]
+
+        if self.stereo:
+            # interleave left/right bins so the bands index the stereo-merged spectrum
+            freq_indices = repeat(freq_indices, "f -> f s", s=2)
+            freq_indices = freq_indices * 2 + torch.arange(2)
+            freq_indices = rearrange(freq_indices, "f s -> (f s)")
+
+        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
+        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
+
+        return freq_indices, freqs_per_band, num_freqs_per_band, num_bands_per_freq
+
+    def infer(
+        self, audio: Union[str, np.ndarray], chunk_size=88200, overlap=0.25, num_stems=4
+    ):
+        if isinstance(audio, str):
+            wav, _ = librosa.load(audio, sr=self.sample_rate, mono=not self.stereo)
+        else:
+            wav = audio
+
+        if self.stereo and wav.shape[0] != 2:
+            wav = wav.transpose()
+
+        # standardize by the mono reference statistics; undone after separation
+        ref = wav.mean(0)
+        ref_mean = ref.mean()
+        ref_std = ref.std()
+        preprocessed_wav = (wav - ref_mean) / (ref_std + 1e-8)
+
+        out = self.apply_model(
+            self.model,
+            preprocessed_wav[None],
+            self.freq_indices,
+            self.num_bands_per_freq,
+            segment=chunk_size,
+            overlap=overlap,
+            len_model_sources=num_stems,
+        )
+
+        out *= ref_std + 1e-8
+        out += ref_mean
+
+        return out
+
+    def preprocess(self, mix):
+        device = torch.device("cpu")
+
+        if isinstance(mix, np.ndarray):
+            mix = torch.from_numpy(mix)
+        b, c, l = mix.shape
+        mix = mix.view(-1, l)
+
+        stft_window = torch.hann_window(self.stft_kwargs["win_length"], device=device)
+
+        stft_repr = torch.stft(
+            mix, **self.stft_kwargs, window=stft_window, return_complex=True
+        )
+        stft_repr = torch.view_as_real(stft_repr)
+
+        # merge stereo / mono into the frequency, with frequency leading dimension,
+        # for band splitting; equivalent to:
+        # stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')
+        s, f, t, c = stft_repr.shape
+        stft_repr = (
+            stft_repr.unsqueeze(0)
+            .reshape(b, s, f, t, c)
+            .transpose(2, 1)
+            .reshape(b, -1, t, c)
+        )
+
+        return stft_repr.numpy()
+
+    def postprocess(
+        self,
+        masks,
+        stft_repr,
+        freq_indices,
+        num_bands_per_freq,
+        audio_len,
+        num_stems=4,
+        channels=2,
+    ):
+        masks = torch.from_numpy(masks)
+        stft_repr = torch.from_numpy(stft_repr)
+        batch = 1
+        istft_length = audio_len
+
+        device = torch.device("cpu")
+        stft_window = torch.hann_window(self.stft_kwargs["win_length"], device=device)
+
+        # modulate frequency representation
+        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
+
+        # complex number multiplication
+        stft_repr = torch.view_as_complex(stft_repr)
+        masks = torch.view_as_complex(masks)
+
+        masks = masks.type(stft_repr.dtype)
+
+        # need to average the estimated mask for the overlapped frequencies
+        scatter_indices = repeat(
+            freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1]
+        )
+
+        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
+        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
+            2, scatter_indices, masks
+        )
+
+        denom = repeat(num_bands_per_freq, "f -> (f r) 1", r=channels)
+
+        masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+        # modulate stft repr with estimated mask
+        stft_repr = stft_repr * masks_averaged
+
+        # istft
+        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=2)
+
+        recon_audio = torch.istft(
+            stft_repr,
+            **self.stft_kwargs,
+            window=stft_window,
+            return_complex=False,
+            length=istft_length
+        )
+
+        recon_audio = rearrange(
+            recon_audio, "(b n s) t -> b n s t", b=batch, s=2, n=num_stems
+        )
+
+        if num_stems == 1:
+            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
+
+        return recon_audio.numpy()
+
+    def apply_model(
+        self,
+        model,
+        mix,
+        freq_indices,
+        num_bands_per_freq,
+        segment,
+        overlap: float = 0.25,
+        len_model_sources=4,
+    ):
+        model_weights = [1.0] * len_model_sources
+        totals = [0.0] * len_model_sources
+        batch, channels, length = mix.shape
+
+        stride = int((1 - overlap) * segment)
+        futures = []
+
+        for offset in tqdm.tqdm(range(0, length, stride)):
+            # zero-pad the final chunk up to the fixed segment length the model expects
+            chunk = mix[..., offset : offset + segment]
+            audio_len = chunk.shape[-1]
+            if chunk.shape[-1] < segment:
+                chunk = np.concatenate(
+                    [
+                        chunk,
+                        np.zeros(
+                            (batch, channels, segment - chunk.shape[-1]),
+                            dtype=np.float32,
+                        ),
+                    ],
+                    axis=-1,
+                )
+
+            stft_input = self.preprocess(chunk)
+            masks = model.run(None, {"stft_input": stft_input})[0]
+            future = self.postprocess(
+                masks,
+                stft_input,
+                freq_indices,
+                num_bands_per_freq,
+                audio_len,
+                num_stems=len_model_sources,
+            )
+            future = future[..., :audio_len]
+
+            futures.append((future, offset))
+
+        out = np.zeros((batch, len_model_sources, channels, length))
+        sum_weight = np.zeros((length,))
+        # triangular window used to cross-fade overlapping chunks (overlap-add)
+        weight = np.concatenate(
+            [
+                np.arange(1, segment // 2 + 1),
+                np.arange(segment - segment // 2, 0, -1),
+            ],
+            axis=-1,
+        )
+        weight = weight / weight.max()
+        for future, offset in futures:
+            chunk_out = future
+            chunk_length = chunk_out.shape[-1]
+            out[..., offset : offset + segment] += weight[:chunk_length] * chunk_out
+            sum_weight[offset : offset + segment] += weight[:chunk_length]
+        out /= sum_weight
+
+        for k, inst_weight in enumerate(model_weights):
+            out[:, k, :, :] *= inst_weight
+            totals[k] += inst_weight
+        for k in range(out.shape[1]):
+            out[:, k, :, :] /= totals[k]
+        return out[0]
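
For reference, a minimal usage sketch of the class above (the model path matches the mel_band_roformer.axmodel shipped in this commit; song.wav stands in for any 44.1 kHz input, and main.py below wraps the same calls in a CLI):

    import numpy as np
    import soundfile as sf
    from MelBandRoformer import MelBandRoformer

    model = MelBandRoformer("./mel_band_roformer.axmodel")

    # infer() accepts a path (or a float32 array) and returns an array of shape
    # (num_stems, channels, samples), ordered drums, bass, other, vocals
    stems = model.infer("song.wav", chunk_size=88200, overlap=0.25, num_stems=4)

    for name, source in zip(["drums", "bass", "other", "vocals"], stems):
        # peak-normalize with 1% headroom, then write as (samples, channels)
        source = source / max(1.01 * np.abs(source).max(), 1)
        sf.write(f"{name}.wav", source.T, samplerate=model.sample_rate)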
config.json ADDED
File without changes
gradio_app.py ADDED
@@ -0,0 +1,87 @@
+import gradio as gr
+import numpy as np
+import soundfile as sf
+import os
+import shutil
+from MelBandRoformer import MelBandRoformer
+
+model = MelBandRoformer("./mel_band_roformer.axmodel")
+print("Model loaded")
+
+
+def cleanup_temp_files(files):
+    for file in files:
+        if os.path.exists(file):
+            if os.path.isdir(file):
+                shutil.rmtree(file)  # safer than shelling out to `rm -rf`
+            else:
+                os.remove(file)
+
+
+def process_audio(input_file, pr=gr.Progress(track_tqdm=True)):
+    global model
+
+    output_path = "output"
+    cleanup_temp_files([output_path])
+
+    print("Running model")
+    out = model.infer(input_file)
+
+    audio_name = os.path.splitext(os.path.basename(input_file))[0]
+    os.makedirs(os.path.join(output_path, audio_name), exist_ok=True)
+
+    stem_names = ["drums", "bass", "other", "vocals"]
+    output_files = []
+    print("Saving audio...")
+    for i in range(out.shape[0]):
+        source = out[i]
+        # peak-normalize only when the signal would clip (1% headroom)
+        source = source / max(1.01 * np.abs(source).max(), 1)
+
+        # soundfile expects (samples, channels)
+        if source.shape[1] != 2:
+            source = source.transpose()
+
+        audio_path = os.path.join(
+            output_path,
+            audio_name,
+            f"{stem_names[i]}.wav",
+        )
+        print(f"Save {stem_names[i]} to {audio_path}")
+
+        sf.write(audio_path, source, samplerate=model.sample_rate)
+        output_files.append(audio_path)
+
+    return [
+        gr.Audio(output_files[0], type="filepath", sources=None, editable=False),
+        gr.Audio(output_files[1], type="filepath", sources=None, editable=False),
+        gr.Audio(output_files[2], type="filepath", sources=None, editable=False),
+        gr.Audio(output_files[3], type="filepath", sources=None, editable=False),
+    ]
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("## Stem Separation")
+    gr.Markdown(
+        "Upload a WAV file and the model will split it into four stems: "
+        "drums, bass, other, and vocals."
+    )
+
+    audio_input = gr.Audio(type="filepath", label="Upload WAV file", editable=False)
+
+    with gr.Tab("Drums"):
+        drums_audio = gr.Audio(type="filepath", label="drums")
+
+    with gr.Tab("Bass"):
+        bass_audio = gr.Audio(type="filepath", label="bass")
+
+    with gr.Tab("Other"):
+        other_audio = gr.Audio(type="filepath", label="other")
+
+    with gr.Tab("Vocals"):
+        vocals_audio = gr.Audio(type="filepath", label="vocals")
+
+    submit_btn = gr.Button("Process audio")
+
+    submit_btn.click(
+        fn=process_audio,
+        inputs=[audio_input],
+        outputs=[drums_audio, bass_audio, other_audio, vocals_audio],
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0")
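
As written, demo.launch(server_name="0.0.0.0") binds the demo to all network interfaces, so running python gradio_app.py serves the UI on Gradio's default port (7860 unless overridden) and makes it reachable from other machines on the network.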
main.py ADDED
@@ -0,0 +1,119 @@
+import numpy as np
+import argparse
+import os
+import soundfile as sf
+import glob
+from MelBandRoformer import MelBandRoformer
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_audio", "-i", type=str, required=True, help="Input audio file (.wav)"
+    )
+    parser.add_argument(
+        "--output_path",
+        "-o",
+        type=str,
+        required=False,
+        default="./output",
+        help="Separated wav output path",
+    )
+    parser.add_argument(
+        "--model_path",
+        "-m",
+        type=str,
+        required=False,
+        default="./mel_band_roformer.axmodel",
+    )
+    parser.add_argument("--overlap", type=float, required=False, default=0.25)
+    parser.add_argument(
+        "--segment",
+        type=int,  # a sample count used for array slicing, so it must be an integer
+        required=False,
+        default=88200,
+        help="number of samples per model chunk",
+    )
+    parser.add_argument(
+        "--num_stems", type=int, default=4, help="number of stems the model outputs"
+    )
+    parser.add_argument("--sample_rate", type=int, default=44100)
+    parser.add_argument("--n_fft", type=int, default=2048)
+    parser.add_argument("--hop_len", type=int, default=441)
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    assert os.path.exists(args.input_audio), f"Input audio {args.input_audio} does not exist"
+    assert os.path.exists(args.model_path), f"Model {args.model_path} does not exist"
+    os.makedirs(args.output_path, exist_ok=True)
+
+    input_audio = args.input_audio
+    output_path = args.output_path
+    model_path = args.model_path
+    segment = args.segment
+    num_stems = args.num_stems
+    target_sr = args.sample_rate
+
+    print(f"Input audio: {input_audio}")
+    print(f"Output path: {output_path}")
+    print(f"Model: {model_path}")
+    print(f"Overlap: {args.overlap}")
+
+    if os.path.isdir(input_audio):
+        types = ("*.wav", "*.mp3", "*.flac")  # the tuple of supported file types
+        input_audios = []
+        for files in types:
+            input_audios.extend(glob.glob(f"{input_audio}/**/{files}", recursive=True))
+    else:
+        input_audios = [input_audio]
+
+    mel_band = MelBandRoformer(
+        model_path,
+        stft_n_fft=args.n_fft,
+        stft_win_length=args.n_fft,
+        stft_hop_length=args.hop_len,
+        sample_rate=target_sr,
+    )
+
+    for input_audio in input_audios:
+        out = mel_band.infer(
+            input_audio,
+            chunk_size=segment,
+            overlap=args.overlap,
+            num_stems=num_stems,
+        )
+
+        audio_name = os.path.splitext(os.path.basename(input_audio))[0]
+        os.makedirs(os.path.join(output_path, audio_name), exist_ok=True)
+
+        stem_names = ["drums", "bass", "other", "vocals"]
+        print("Saving audio...")
+        for i in range(out.shape[0]):
+            source = out[i]
+            # peak-normalize only when the signal would clip (1% headroom)
+            source = source / max(1.01 * np.abs(source).max(), 1)
+
+            # soundfile expects (samples, channels)
+            if source.shape[1] != 2:
+                source = source.transpose()
+
+            if num_stems == 4:
+                audio_path = os.path.join(
+                    output_path,
+                    audio_name,
+                    f"{stem_names[i]}.wav",
+                )
+                print(f"Save {stem_names[i]} to {audio_path}")
+            else:
+                audio_path = os.path.join(
+                    output_path,
+                    audio_name,
+                    f"stem_{i}.wav",
+                )
+                print(f"Save stem {i} to {audio_path}")
+
+            sf.write(audio_path, source, samplerate=target_sr)
+
+
+if __name__ == "__main__":
+    main()
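
Typical usage, per the arguments defined in get_args() above: python main.py -i song.wav separates a single file into ./output/song/, while passing a directory to -i batch-processes every *.wav, *.mp3, and *.flac found under it recursively (song.wav is a placeholder input name).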
mel_band_roformer.axmodel ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24a10bcac63b6a90d00de19063a20660a599b961bb56ede0089fe4bfacd464b3
+size 95657444
requirements.txt ADDED
@@ -0,0 +1,9 @@
+numpy<2.0
+soundfile==0.13.1
+librosa==0.9.1
+tqdm
+onnxruntime
+einops
+torch
+axengine @ git+https://github.com/AXERA-TECH/pyaxengine.git@0.1.3.rc1
+gradio
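
The stack installs with pip install -r requirements.txt; axengine is pulled straight from the pinned pyaxengine tag rather than a package index, and numpy is capped below 2.0, presumably to stay compatible with the older librosa pin.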
screenshot.png ADDED