|
|
import numpy as np |
|
|
import argparse |
|
|
import os |
|
|
import soundfile as sf |
|
|
import glob |
|
|
from MelBandRoformer import MelBandRoformer |
|
|
|
|
|
|
|
|
def get_args(): |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument( |
|
|
"--input_audio", "-i", type=str, required=True, help="Input audio file(.wav)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output_path", |
|
|
"-o", |
|
|
type=str, |
|
|
required=False, |
|
|
default="./output", |
|
|
help="Seperated wav path", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--model_path", |
|
|
"-m", |
|
|
type=str, |
|
|
required=False, |
|
|
default="./mel_band_roformer.axmodel", |
|
|
) |
|
|
parser.add_argument("--overlap", type=float, required=False, default=0.25) |
|
|
parser.add_argument( |
|
|
"--segment", |
|
|
type=float, |
|
|
required=False, |
|
|
default=88200, |
|
|
help="num samples of model", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--num_stems", type=int, default=4, help="num of instruments of model" |
|
|
) |
|
|
parser.add_argument("--sample_rate", type=int, default=44100) |
|
|
parser.add_argument("--n_fft", type=int, default=2048) |
|
|
parser.add_argument("--hop_len", type=int, default=441) |
|
|
parser.add_argument("--output_format", type=str, choices=["wav", "mp3"], default="wav") |
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
def main(): |
|
|
args = get_args() |
|
|
assert os.path.exists(args.input_audio), f"Input audio {args.input_audio} not exist" |
|
|
assert os.path.exists(args.model_path), f"Model {args.model_path} not exist" |
|
|
os.makedirs(args.output_path, exist_ok=True) |
|
|
|
|
|
input_audio = args.input_audio |
|
|
output_path = args.output_path |
|
|
model_path = args.model_path |
|
|
segment = args.segment |
|
|
num_stems = args.num_stems |
|
|
target_sr = args.sample_rate |
|
|
|
|
|
print(f"Input audio: {input_audio}") |
|
|
print(f"Output path: {output_path}") |
|
|
print(f"Model: {model_path}") |
|
|
print(f"Overlap: {args.overlap}") |
|
|
|
|
|
if os.path.isdir(input_audio): |
|
|
types = ("*.wav", "*.mp3", "*.flac") |
|
|
input_audios = [] |
|
|
for files in types: |
|
|
input_audios.extend(glob.glob(f"{input_audio}/**/{files}", recursive=True)) |
|
|
else: |
|
|
input_audios = [input_audio] |
|
|
|
|
|
mel_band = MelBandRoformer( |
|
|
model_path, |
|
|
stft_n_fft=args.n_fft, |
|
|
stft_win_length=args.n_fft, |
|
|
stft_hop_length=args.hop_len, |
|
|
sample_rate=target_sr, |
|
|
) |
|
|
|
|
|
for input_audio in input_audios: |
|
|
out = mel_band.infer( |
|
|
input_audio, |
|
|
chunk_size=segment, |
|
|
overlap=args.overlap, |
|
|
num_stems=num_stems, |
|
|
) |
|
|
|
|
|
audio_name = os.path.splitext(os.path.basename(input_audio))[0] |
|
|
os.makedirs(os.path.join(output_path, audio_name), exist_ok=True) |
|
|
|
|
|
stem_names = ["drums", "bass", "other", "vocals"] |
|
|
print("Saving audio...") |
|
|
for i in range(out.shape[0]): |
|
|
source = out[i] |
|
|
source = source / max(1.01 * np.abs(source).max(), 1) |
|
|
|
|
|
if source.shape[1] != 2: |
|
|
source = source.transpose() |
|
|
|
|
|
if num_stems == 4: |
|
|
audio_path = os.path.join( |
|
|
output_path, |
|
|
audio_name, |
|
|
f"{stem_names[i]}.{args.output_format}", |
|
|
) |
|
|
print(f"Save {stem_names[i]} to {audio_path}") |
|
|
else: |
|
|
audio_path = os.path.join( |
|
|
output_path, |
|
|
audio_name, |
|
|
f"stem_{i}.{args.output_format}", |
|
|
) |
|
|
print(f"Save stem {i} to {audio_path}") |
|
|
|
|
|
if args.output_format == "mp3": |
|
|
sf.write(audio_path, source, samplerate=target_sr, bitrate_mode='CONSTANT', compression_level=0.99) |
|
|
else: |
|
|
sf.write(audio_path, source, samplerate=target_sr) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|