Upload folder using huggingface_hub
- .gitignore +2 -0
- MelBandRoformer.py +293 -0
- config.json +0 -0
- gradio_app.py +87 -0
- main.py +119 -0
- mel_band_roformer.axmodel +3 -0
- requirements.txt +9 -0
- screenshot.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__
*.wav
MelBandRoformer.py
ADDED
@@ -0,0 +1,293 @@
import axengine as axe
import numpy as np
import soundfile as sf
import librosa
import torch
import tqdm
from librosa import filters
from einops import rearrange, reduce, repeat
from typing import Union


class MelBandRoformer:
    def __init__(
        self,
        model_path,
        *,
        stft_n_fft=2048,
        stft_win_length=2048,
        stft_hop_length=441,
        stft_normalized=False,
        sample_rate=44100,
        num_bands=60,
        stereo=True,
    ):
        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )
        self.sample_rate = sample_rate
        self.num_bands = num_bands
        self.stereo = stereo
        self.num_channels = 2 if stereo else 1

        self.freq_indices, _, _, self.num_bands_per_freq = self.calc_freq_indices()

        self.model = axe.InferenceSession(
            model_path,
            providers=["AxEngineExecutionProvider", "AXCLRTExecutionProvider"],
        )

    def calc_freq_indices(self):
        # number of frequency bins produced by the configured STFT
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(self.stft_kwargs["n_fft"]),
            return_complex=True,
        ).shape[1]

        # create the mel filter bank with librosa.filters.mel, as in section 2 of the paper
        mel_filter_bank_numpy = filters.mel(
            sr=self.sample_rate, n_fft=self.stft_kwargs["n_fft"], n_mels=self.num_bands
        )

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # librosa's filter bank does not include the first frequency bin; force a value for now
        mel_filter_bank[0][0] = 1.0

        # on some systems/envs the last position comes back as 0.0 instead of ~1.9e-18,
        # so force a positive value there as well
        mel_filter_bank[-1, -1] = 1.0

        # binarize as in the paper (estimated masks are averaged over overlapping regions later)
        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(
            dim=0
        ).all(), "all frequencies need to be covered by all bands for now"

        repeated_freq_indices = repeat(
            torch.arange(freqs), "f -> b f", b=self.num_bands
        )
        freq_indices = repeated_freq_indices[freqs_per_band]

        if self.stereo:
            # interleave left/right: frequency bin f maps to bins 2f and 2f + 1
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")

        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

        return freq_indices, freqs_per_band, num_freqs_per_band, num_bands_per_freq

    def infer(
        self, audio: Union[str, np.ndarray], chunk_size=88200, overlap=0.25, num_stems=4
    ):
        if isinstance(audio, str):
            wav, _ = librosa.load(audio, sr=self.sample_rate, mono=not self.stereo)
        else:
            wav = audio

        if self.stereo and wav.shape[0] != 2:
            wav = wav.transpose()

        # normalize with the mono reference statistics; undone after separation
        ref = wav.mean(0)
        ref_mean = ref.mean()
        ref_std = ref.std()
        preprocessed_wav = (wav - ref_mean) / (ref_std + 1e-8)

        out = self.apply_model(
            self.model,
            preprocessed_wav[None],
            self.freq_indices,
            self.num_bands_per_freq,
            segment=chunk_size,
            overlap=overlap,
            len_model_sources=num_stems,
        )

        out *= ref_std + 1e-8
        out += ref_mean

        return out

    def preprocess(self, mix):
        device = torch.device("cpu")

        if isinstance(mix, np.ndarray):
            mix = torch.from_numpy(mix)
        b, c, l = mix.shape
        mix = mix.view(-1, l)

        stft_window = torch.hann_window(self.stft_kwargs["win_length"], device=device)

        stft_repr = torch.stft(
            mix, **self.stft_kwargs, window=stft_window, return_complex=True
        )
        stft_repr = torch.view_as_real(stft_repr)

        # merge stereo / mono into the frequency dimension, with frequency leading,
        # for band splitting; equivalent to rearrange(stft_repr, 'b s f t c -> b (f s) t c')
        s, f, t, c = stft_repr.shape
        stft_repr = (
            stft_repr.unsqueeze(0)
            .reshape(b, s, f, t, c)
            .transpose(2, 1)
            .reshape(b, -1, t, c)
        )

        return stft_repr.numpy()

    def postprocess(
        self,
        masks,
        stft_repr,
        freq_indices,
        num_bands_per_freq,
        audio_len,
        num_stems=4,
        channels=2,
    ):
        masks = torch.from_numpy(masks)
        stft_repr = torch.from_numpy(stft_repr)
        batch = 1
        istft_length = audio_len

        device = torch.device("cpu")
        stft_window = torch.hann_window(self.stft_kwargs["win_length"], device=device)

        # modulate the frequency representation
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication
        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # average the estimated masks over the overlapping frequencies
        scatter_indices = repeat(
            freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1]
        )

        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
            2, scatter_indices, masks
        )

        denom = repeat(num_bands_per_freq, "f -> (f r) 1", r=channels)

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        # modulate the stft representation with the estimated mask
        stft_repr = stft_repr * masks_averaged

        # istft
        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=2)

        recon_audio = torch.istft(
            stft_repr,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=False,
            length=istft_length,
        )

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", b=batch, s=2, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        return recon_audio.numpy()

    def apply_model(
        self,
        model,
        mix,
        freq_indices,
        num_bands_per_freq,
        segment,
        overlap: float = 0.25,
        len_model_sources=4,
    ):
        model_weights = [1.0] * len_model_sources
        totals = [0.0] * len_model_sources
        batch, channels, length = mix.shape

        stride = int((1 - overlap) * segment)
        futures = []

        for offset in tqdm.tqdm(range(0, length, stride)):
            chunk = mix[..., offset : offset + segment]
            audio_len = chunk.shape[-1]
            # zero-pad the last chunk up to the segment length
            if chunk.shape[-1] < segment:
                chunk = np.concatenate(
                    [
                        chunk,
                        np.zeros(
                            (batch, channels, segment - chunk.shape[-1]),
                            dtype=np.float32,
                        ),
                    ],
                    axis=-1,
                )

            stft_input = self.preprocess(chunk)
            masks = model.run(None, {"stft_input": stft_input})[0]
            future = self.postprocess(
                masks,
                stft_input,
                freq_indices,
                num_bands_per_freq,
                audio_len,
                num_stems=len_model_sources,
            )
            future = future[..., :audio_len]

            futures.append((future, offset))

        # overlap-add the chunks with a triangular weight window
        out = np.zeros((batch, len_model_sources, channels, length))
        sum_weight = np.zeros((length,))
        weight = np.concatenate(
            [
                np.arange(1, segment // 2 + 1),
                np.arange(segment - segment // 2, 0, -1),
            ],
            axis=-1,
        )
        weight = weight / weight.max()
        for future, offset in futures:
            chunk_out = future
            chunk_length = chunk_out.shape[-1]
            out[..., offset : offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset : offset + segment] += weight[:chunk_length]
        out /= sum_weight

        for k, inst_weight in enumerate(model_weights):
            out[:, k, :, :] *= inst_weight
            totals[k] += inst_weight
        for k in range(out.shape[1]):
            out[:, k, :, :] /= totals[k]
        return out[0]
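Since the class ships without a docstring, here is a minimal usage sketch distilled from main.py further down; the input path "song.wav" is a placeholder, and the peak normalization that main.py applies before saving is omitted here.

import soundfile as sf
from MelBandRoformer import MelBandRoformer

# Load the compiled axmodel from this repo and separate a stereo 44.1 kHz file.
model = MelBandRoformer("./mel_band_roformer.axmodel")
stems = model.infer("song.wav", chunk_size=88200, overlap=0.25, num_stems=4)

# stems has shape (num_stems, channels, samples); soundfile expects (samples, channels).
for name, source in zip(["drums", "bass", "other", "vocals"], stems):
    sf.write(f"{name}.wav", source.T, samplerate=model.sample_rate)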
config.json
ADDED
File without changes
gradio_app.py
ADDED
@@ -0,0 +1,87 @@
import gradio as gr
import numpy as np
import soundfile as sf
import os
import shutil
from MelBandRoformer import MelBandRoformer

model = MelBandRoformer("./mel_band_roformer.axmodel")
print("Model loaded")


def cleanup_temp_files(files):
    for file in files:
        if os.path.exists(file):
            if os.path.isdir(file):
                shutil.rmtree(file)
            else:
                os.remove(file)


def process_audio(input_file, pr=gr.Progress(track_tqdm=True)):
    global model

    output_path = "output"
    cleanup_temp_files([output_path])

    print("Running model")
    out = model.infer(input_file)

    audio_name = os.path.splitext(os.path.basename(input_file))[0]
    os.makedirs(os.path.join(output_path, audio_name), exist_ok=True)

    stem_names = ["drums", "bass", "other", "vocals"]
    output_files = []
    print("Saving audio...")
    for i in range(out.shape[0]):
        source = out[i]
        # peak-normalize to avoid clipping
        source = source / max(1.01 * np.abs(source).max(), 1)

        # soundfile expects (samples, channels)
        if source.shape[1] != 2:
            source = source.transpose()

        audio_path = os.path.join(
            output_path,
            audio_name,
            f"{stem_names[i]}.wav",
        )
        print(f"Save {stem_names[i]} to {audio_path}")

        sf.write(audio_path, source, samplerate=model.sample_rate)
        output_files.append(audio_path)

    return [
        gr.Audio(output_files[0], type="filepath", sources=None, editable=False),
        gr.Audio(output_files[1], type="filepath", sources=None, editable=False),
        gr.Audio(output_files[2], type="filepath", sources=None, editable=False),
        gr.Audio(output_files[3], type="filepath", sources=None, editable=False),
    ]


with gr.Blocks() as demo:
    gr.Markdown("## Stem Separation")
    gr.Markdown(
        "Upload a WAV file; the model separates it into four stems "
        "(drums, bass, other, vocals), one per instrument group."
    )

    audio_input = gr.Audio(type="filepath", label="Upload WAV file", editable=False)

    with gr.Tab("Drums"):
        drums_audio = gr.Audio(type="filepath", label="drums")

    with gr.Tab("Bass"):
        bass_audio = gr.Audio(type="filepath", label="bass")

    with gr.Tab("Other"):
        other_audio = gr.Audio(type="filepath", label="other")

    with gr.Tab("Vocals"):
        vocals_audio = gr.Audio(type="filepath", label="vocals")

    submit_btn = gr.Button("Process audio")

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[drums_audio, bass_audio, other_audio, vocals_audio],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
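To try the demo, run python gradio_app.py on the device and open the printed URL in a browser; launch() binds to all interfaces here, and Gradio's default port is 7860.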
main.py
ADDED
@@ -0,0 +1,119 @@
import numpy as np
import argparse
import os
import soundfile as sf
import glob
from MelBandRoformer import MelBandRoformer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_audio", "-i", type=str, required=True,
        help="Input audio file (.wav) or directory",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        type=str,
        required=False,
        default="./output",
        help="Output directory for the separated wav files",
    )
    parser.add_argument(
        "--model_path",
        "-m",
        type=str,
        required=False,
        default="./mel_band_roformer.axmodel",
    )
    parser.add_argument("--overlap", type=float, required=False, default=0.25)
    parser.add_argument(
        "--segment",
        type=int,
        required=False,
        default=88200,
        help="number of samples per model chunk",
    )
    parser.add_argument(
        "--num_stems", type=int, default=4, help="number of instruments of the model"
    )
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--hop_len", type=int, default=441)
    return parser.parse_args()


def main():
    args = get_args()
    assert os.path.exists(args.input_audio), f"Input audio {args.input_audio} does not exist"
    assert os.path.exists(args.model_path), f"Model {args.model_path} does not exist"
    os.makedirs(args.output_path, exist_ok=True)

    input_audio = args.input_audio
    output_path = args.output_path
    model_path = args.model_path
    segment = args.segment
    num_stems = args.num_stems
    target_sr = args.sample_rate

    print(f"Input audio: {input_audio}")
    print(f"Output path: {output_path}")
    print(f"Model: {model_path}")
    print(f"Overlap: {args.overlap}")

    if os.path.isdir(input_audio):
        types = ("*.wav", "*.mp3", "*.flac")  # file types to search for
        input_audios = []
        for files in types:
            input_audios.extend(glob.glob(f"{input_audio}/**/{files}", recursive=True))
    else:
        input_audios = [input_audio]

    mel_band = MelBandRoformer(
        model_path,
        stft_n_fft=args.n_fft,
        stft_win_length=args.n_fft,
        stft_hop_length=args.hop_len,
        sample_rate=target_sr,
    )

    for input_audio in input_audios:
        out = mel_band.infer(
            input_audio,
            chunk_size=segment,
            overlap=args.overlap,
            num_stems=num_stems,
        )

        audio_name = os.path.splitext(os.path.basename(input_audio))[0]
        os.makedirs(os.path.join(output_path, audio_name), exist_ok=True)

        stem_names = ["drums", "bass", "other", "vocals"]
        print("Saving audio...")
        for i in range(out.shape[0]):
            source = out[i]
            # peak-normalize to avoid clipping
            source = source / max(1.01 * np.abs(source).max(), 1)

            # soundfile expects (samples, channels)
            if source.shape[1] != 2:
                source = source.transpose()

            if num_stems == 4:
                audio_path = os.path.join(
                    output_path,
                    audio_name,
                    f"{stem_names[i]}.wav",
                )
                print(f"Save {stem_names[i]} to {audio_path}")
            else:
                audio_path = os.path.join(
                    output_path,
                    audio_name,
                    f"stem_{i}.wav",
                )
                print(f"Save stem {i} to {audio_path}")

            sf.write(audio_path, source, samplerate=target_sr)


if __name__ == "__main__":
    main()
mel_band_roformer.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24a10bcac63b6a90d00de19063a20660a599b961bb56ede0089fe4bfacd464b3
size 95657444
requirements.txt
ADDED
@@ -0,0 +1,9 @@
numpy<2.0
soundfile==0.13.1
librosa==0.9.1
tqdm
onnxruntime
einops
torch
axengine @ git+https://github.com/AXERA-TECH/pyaxengine.git@0.1.3.rc1
gradio
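Install the dependencies with pip install -r requirements.txt. Note that axengine is the Python binding for AXERA's NPU runtime, so inference on the .axmodel is expected to run on AXERA hardware, matching the AxEngineExecutionProvider / AXCLRTExecutionProvider providers requested in MelBandRoformer.py.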
screenshot.png
ADDED