import argparse
import glob
import os

import numpy as np
import soundfile as sf

from MelBandRoformer import MelBandRoformer


def get_args():
    """Parse command-line arguments for the source-separation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_audio", "-i", type=str, required=True,
        help="Input audio file(.wav)",
    )
    parser.add_argument(
        "--output_path", "-o", type=str, required=False, default="./output",
        help="Separated wav path",
    )
    parser.add_argument(
        "--model_path", "-m", type=str, required=False,
        default="./mel_band_roformer.axmodel",
    )
    parser.add_argument("--overlap", type=float, required=False, default=0.25)
    # BUG FIX: a sample count is integral; the original used type=float,
    # so a CLI-supplied --segment became a float chunk size downstream.
    # The default (88200) was already an int, so this is backward-compatible.
    parser.add_argument(
        "--segment", type=int, required=False, default=88200,
        help="num samples of model",
    )
    parser.add_argument(
        "--num_stems", type=int, default=4, help="num of instruments of model"
    )
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--hop_len", type=int, default=441)
    parser.add_argument(
        "--output_format", type=str, choices=["wav", "mp3"], default="wav"
    )
    return parser.parse_args()


def main():
    """Separate audio into stems with a Mel-Band Roformer model.

    Accepts either a single audio file or a directory (searched recursively
    for .wav/.mp3/.flac) and writes one file per stem under
    ``<output_path>/<audio_name>/``.
    """
    args = get_args()
    # Explicit checks instead of `assert` — asserts are stripped under -O.
    if not os.path.exists(args.input_audio):
        raise FileNotFoundError(f"Input audio {args.input_audio} not exist")
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model {args.model_path} not exist")
    os.makedirs(args.output_path, exist_ok=True)

    input_audio = args.input_audio
    output_path = args.output_path
    model_path = args.model_path
    segment = args.segment
    num_stems = args.num_stems
    target_sr = args.sample_rate

    print(f"Input audio: {input_audio}")
    print(f"Output path: {output_path}")
    print(f"Model: {model_path}")
    print(f"Overlap: {args.overlap}")

    # Collect inputs: a directory is scanned recursively for supported types.
    if os.path.isdir(input_audio):
        types = ("*.wav", "*.mp3", "*.flac")  # the tuple of file types
        input_audios = []
        for files in types:
            input_audios.extend(
                glob.glob(f"{input_audio}/**/{files}", recursive=True)
            )
    else:
        input_audios = [input_audio]

    mel_band = MelBandRoformer(
        model_path,
        stft_n_fft=args.n_fft,
        stft_win_length=args.n_fft,
        stft_hop_length=args.hop_len,
        sample_rate=target_sr,
    )

    for input_audio in input_audios:
        out = mel_band.infer(
            input_audio,
            chunk_size=segment,
            overlap=args.overlap,
            num_stems=num_stems,
        )
        audio_name = os.path.splitext(os.path.basename(input_audio))[0]
        os.makedirs(os.path.join(output_path, audio_name), exist_ok=True)

        stem_names = ["drums", "bass", "other", "vocals"]
        print("Saving audio...")
        for i in range(out.shape[0]):
            source = out[i]
            # Peak-normalize with ~1% headroom; the max(..., 1) floor keeps
            # quiet stems from being amplified (and avoids divide-by-zero).
            source = source / max(1.01 * np.abs(source).max(), 1)
            # soundfile expects (frames, channels) layout.
            # NOTE(review): assumes stereo model output — confirm for mono.
            if source.shape[1] != 2:
                source = source.transpose()
            if num_stems == 4:
                audio_path = os.path.join(
                    output_path,
                    audio_name,
                    f"{stem_names[i]}.{args.output_format}",
                )
                print(f"Save {stem_names[i]} to {audio_path}")
            else:
                audio_path = os.path.join(
                    output_path,
                    audio_name,
                    f"stem_{i}.{args.output_format}",
                )
                print(f"Save stem {i} to {audio_path}")
            if args.output_format == "mp3":
                sf.write(
                    audio_path,
                    source,
                    samplerate=target_sr,
                    bitrate_mode='CONSTANT',
                    compression_level=0.99,
                )
            else:
                sf.write(audio_path, source, samplerate=target_sr)


if __name__ == "__main__":
    main()