# yangrongzhao
# fix sf arg
# 2a5f20f
import numpy as np
import argparse
import os
import soundfile as sf
import glob
from MelBandRoformer import MelBandRoformer
def get_args():
    """Parse the command-line options for the stem-separation script.

    Returns:
        argparse.Namespace with the attributes ``input_audio``, ``output_path``,
        ``model_path``, ``overlap``, ``segment``, ``num_stems``, ``sample_rate``,
        ``n_fft``, ``hop_len`` and ``output_format``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_audio",
        "-i",
        type=str,
        required=True,
        help="Input audio file, or a directory searched recursively for .wav/.mp3/.flac files",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        type=str,
        required=False,
        default="./output",
        help="Directory for the separated stems",
    )
    parser.add_argument(
        "--model_path",
        "-m",
        type=str,
        required=False,
        default="./mel_band_roformer.axmodel",
        help="Path to the model file",
    )
    # Forwarded unchanged as ``overlap`` to the model's chunked inference.
    parser.add_argument("--overlap", type=float, required=False, default=0.25)
    parser.add_argument(
        "--segment",
        # A sample count must be an integer. With ``type=float`` a CLI value
        # arrived as e.g. 88200.0 while the default stayed the int 88200.
        type=int,
        required=False,
        default=88200,
        help="num samples of model",
    )
    parser.add_argument(
        "--num_stems", type=int, default=4, help="num of instruments of model"
    )
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--hop_len", type=int, default=441)
    parser.add_argument(
        "--output_format", type=str, choices=["wav", "mp3"], default="wav"
    )
    return parser.parse_args()
# Stem labels used when the model is run with 4 stems.
# NOTE(review): assumed to match the model's output order; confirm against the checkpoint.
_STEM_NAMES = ("drums", "bass", "other", "vocals")

# Extensions (compared case-insensitively) collected when --input_audio is a directory.
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")


def _collect_input_audios(input_audio):
    """Return the list of audio files to process for *input_audio*.

    A non-directory path is returned unchanged as a one-element list. A
    directory is searched recursively for regular files whose extension is in
    ``_AUDIO_EXTENSIONS``. The result is sorted so runs are reproducible.
    """
    if not os.path.isdir(input_audio):
        return [input_audio]
    # Escape the root so characters such as "[" in the directory name are
    # matched literally instead of being treated as glob wildcards.
    pattern = os.path.join(glob.escape(input_audio), "**", "*")
    return sorted(
        path
        for path in glob.glob(pattern, recursive=True)
        if os.path.isfile(path)
        and os.path.splitext(path)[1].lower() in _AUDIO_EXTENSIONS
    )


def _prepare_for_write(source):
    """Peak-limit one separated stem and lay it out as (frames, channels)."""
    # Divide by 1.01 * peak only when that exceeds 1. Loud stems are scaled to
    # just below full scale; quieter ones are left untouched.
    source = source / max(1.01 * np.abs(source).max(), 1)
    # soundfile.write expects frames along axis 0. A channels-first stem such
    # as (2, T) is transposed. A frames-first (T, C) stem or a 1-D mono stem
    # is left as-is.
    if source.ndim == 2 and source.shape[0] < source.shape[1]:
        source = source.transpose()
    return source


def _write_audio(audio_path, source, sample_rate, output_format):
    """Write *source* to *audio_path*; soundfile picks the format from the extension."""
    if output_format == "mp3":
        # MP3 encoder settings: constant bitrate, compression_level in [0, 1].
        sf.write(
            audio_path,
            source,
            samplerate=sample_rate,
            bitrate_mode='CONSTANT',
            compression_level=0.99,
        )
    else:
        sf.write(audio_path, source, samplerate=sample_rate)


def main():
    """Separate each input audio file into stems and save one file per stem.

    Stems are written to ``<output_path>/<input basename>/``. They are named
    after ``_STEM_NAMES`` when ``--num_stems`` is 4 and ``stem_<i>`` otherwise.

    Raises:
        FileNotFoundError: if the input audio path or the model path does not exist.
    """
    args = get_args()
    # Explicit exceptions instead of ``assert`` so the checks survive ``python -O``.
    if not os.path.exists(args.input_audio):
        raise FileNotFoundError(f"Input audio {args.input_audio} not exist")
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model {args.model_path} not exist")
    os.makedirs(args.output_path, exist_ok=True)

    input_audio = args.input_audio
    output_path = args.output_path
    model_path = args.model_path
    # A chunk length in samples must be integral; --segment may arrive as a float.
    segment = int(args.segment)
    num_stems = args.num_stems
    target_sr = args.sample_rate
    print(f"Input audio: {input_audio}")
    print(f"Output path: {output_path}")
    print(f"Model: {model_path}")
    print(f"Overlap: {args.overlap}")

    input_audios = _collect_input_audios(input_audio)
    if not input_audios:
        # Without this message an empty directory would exit with no explanation.
        print(f"No .wav/.mp3/.flac files found in {input_audio}")
        return

    mel_band = MelBandRoformer(
        model_path,
        stft_n_fft=args.n_fft,
        stft_win_length=args.n_fft,
        stft_hop_length=args.hop_len,
        sample_rate=target_sr,
    )
    for audio_file in input_audios:
        out = mel_band.infer(
            audio_file,
            chunk_size=segment,
            overlap=args.overlap,
            num_stems=num_stems,
        )
        audio_name = os.path.splitext(os.path.basename(audio_file))[0]
        # NOTE(review): inputs sharing a basename (e.g. a/x.wav and b/x.flac)
        # map to the same directory and overwrite each other's stems.
        audio_dir = os.path.join(output_path, audio_name)
        os.makedirs(audio_dir, exist_ok=True)
        print("Saving audio...")
        # Axis 0 of ``out`` indexes the separated stems.
        for i in range(out.shape[0]):
            source = _prepare_for_write(out[i])
            if num_stems == len(_STEM_NAMES):
                audio_path = os.path.join(
                    audio_dir, f"{_STEM_NAMES[i]}.{args.output_format}"
                )
                print(f"Save {_STEM_NAMES[i]} to {audio_path}")
            else:
                audio_path = os.path.join(
                    audio_dir, f"stem_{i}.{args.output_format}"
                )
                print(f"Save stem {i} to {audio_path}")
            _write_audio(audio_path, source, target_sr, args.output_format)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()