File size: 3,858 Bytes
ab5bd26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9661be
ab5bd26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e44f4
ab5bd26
 
 
 
 
 
19e44f4
ab5bd26
 
 
2a5f20f
 
 
 
ab5bd26
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import argparse
import os
import soundfile as sf
import glob
from MelBandRoformer import MelBandRoformer


def _sample_count(text):
    """Parse a positive whole number of samples for ``--segment``.

    Integral float spellings such as ``"88200.0"`` are still accepted so that
    existing command lines keep working, but the value is returned as ``int``.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a positive whole number.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid sample count: {text!r}") from None
    # is_integer() is False for nan/inf, so those are rejected here too.
    if not value.is_integer() or value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive whole number of samples, got {text!r}"
        )
    return int(value)


def get_args():
    """Parse the command-line options for stem separation.

    Returns:
        argparse.Namespace with ``input_audio``, ``output_path``, ``model_path``,
        ``overlap``, ``segment`` (int), ``num_stems``, ``sample_rate``, ``n_fft``,
        ``hop_len`` and ``output_format``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_audio",
        "-i",
        type=str,
        required=True,
        help="Input audio file, or a directory searched recursively for .wav/.mp3/.flac files",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        type=str,
        required=False,
        default="./output",
        help="Directory where the separated stems are written",
    )
    parser.add_argument(
        "--model_path",
        "-m",
        type=str,
        required=False,
        default="./mel_band_roformer.axmodel",
        help="Path to the compiled model file",
    )
    parser.add_argument(
        "--overlap",
        type=float,
        required=False,
        default=0.25,
        help="Chunk overlap passed to the model's infer()",
    )
    parser.add_argument(
        "--segment",
        # Previously type=float, which turned user-supplied sample counts into
        # floats while the default stayed an int; always yield an int now.
        type=_sample_count,
        required=False,
        default=88200,
        help="num samples of model",
    )
    parser.add_argument(
        "--num_stems", type=int, default=4, help="num of instruments of model"
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=44100,
        help="Sample rate given to the model and written to the output files",
    )
    parser.add_argument(
        "--n_fft", type=int, default=2048, help="STFT FFT size (also used as window length)"
    )
    parser.add_argument("--hop_len", type=int, default=441, help="STFT hop length")
    parser.add_argument("--output_format", type=str, choices=["wav", "mp3"], default="wav")
    return parser.parse_args()


# File extensions picked up when --input_audio is a directory (compared case-insensitively).
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
# Output file names used when the model produces exactly four stems, in output order.
_FOUR_STEM_NAMES = ("drums", "bass", "other", "vocals")


def _collect_input_audios(input_audio):
    """Return the audio files to process for ``input_audio``.

    Args:
        input_audio: Path to a single audio file or to a directory.

    Returns:
        ``[input_audio]`` for a non-directory path. For a directory, every
        regular file below it (recursively) whose extension is in
        ``_AUDIO_EXTENSIONS``, sorted so the processing order is deterministic.
    """
    if not os.path.isdir(input_audio):
        return [input_audio]
    # glob.escape: a directory name containing "[", "*" or "?" must be matched
    # literally rather than interpreted as a glob pattern.
    pattern = os.path.join(glob.escape(input_audio), "**", "*")
    return sorted(
        path
        for path in glob.glob(pattern, recursive=True)
        if os.path.isfile(path) and path.lower().endswith(_AUDIO_EXTENSIONS)
    )


def _save_stems(out, output_dir, num_stems, sample_rate, output_format):
    """Write each separated stem in ``out`` as an audio file under ``output_dir``.

    Args:
        out: Result of ``MelBandRoformer.infer``; iterated along its first axis,
            one entry per stem (each entry assumed 2-D — TODO confirm).
        output_dir: Existing directory the stem files are written into.
        num_stems: Stem count requested from the model; with 4 stems the files
            are named after ``_FOUR_STEM_NAMES``, otherwise ``stem_<i>``.
        sample_rate: Sample rate written into the output files.
        output_format: ``"wav"`` or ``"mp3"``; also used as the file extension.
    """
    for i, source in enumerate(out):
        # Scale down only when the peak would clip (1.01 leaves a little
        # headroom); quieter stems are divided by 1, i.e. left untouched.
        source = source / max(1.01 * np.abs(source).max(), 1)

        # sf.write takes (frames, channels). Heuristic: if the second axis is
        # not 2 stereo channels, assume the stem is channel-first and transpose.
        if source.shape[1] != 2:
            source = source.transpose()

        # Guard the index: a model returning more stems than requested would
        # otherwise raise IndexError on the fixed name table.
        if num_stems == len(_FOUR_STEM_NAMES) and i < len(_FOUR_STEM_NAMES):
            stem_name = _FOUR_STEM_NAMES[i]
            audio_path = os.path.join(output_dir, f"{stem_name}.{output_format}")
            print(f"Save {stem_name} to {audio_path}")
        else:
            audio_path = os.path.join(output_dir, f"stem_{i}.{output_format}")
            print(f"Save stem {i} to {audio_path}")

        if output_format == "mp3":
            sf.write(audio_path, source, samplerate=sample_rate, bitrate_mode='CONSTANT', compression_level=0.99)
        else:
            sf.write(audio_path, source, samplerate=sample_rate)


def main():
    """Run stem separation on the audio file(s) named on the command line.

    Parses the CLI via ``get_args``, loads ``MelBandRoformer`` once, then for
    every input file writes one file per stem into
    ``<output_path>/<input file name without extension>/``.

    Raises:
        FileNotFoundError: If ``--input_audio`` or ``--model_path`` does not exist.
    """
    args = get_args()
    # Explicit checks instead of ``assert``, which ``python -O`` strips out.
    if not os.path.exists(args.input_audio):
        raise FileNotFoundError(f"Input audio {args.input_audio} not exist")
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model {args.model_path} not exist")
    os.makedirs(args.output_path, exist_ok=True)

    input_audio = args.input_audio
    output_path = args.output_path
    model_path = args.model_path
    # --segment is a sample count (its default is the int 88200) but may arrive
    # as a float (e.g. 88200.0); normalize before it reaches infer().
    segment = int(args.segment)
    num_stems = args.num_stems
    target_sr = args.sample_rate

    print(f"Input audio: {input_audio}")
    print(f"Output path: {output_path}")
    print(f"Model: {model_path}")
    print(f"Overlap: {args.overlap}")

    input_audios = _collect_input_audios(input_audio)
    if not input_audios:
        # Avoid loading the model just to do nothing.
        print(f"No {'/'.join(_AUDIO_EXTENSIONS)} files found under {input_audio}")
        return

    mel_band = MelBandRoformer(
        model_path,
        stft_n_fft=args.n_fft,
        stft_win_length=args.n_fft,
        stft_hop_length=args.hop_len,
        sample_rate=target_sr,
    )

    for audio_file in input_audios:
        out = mel_band.infer(
            audio_file,
            chunk_size=segment,
            overlap=args.overlap,
            num_stems=num_stems,
        )

        # NOTE(review): outputs are keyed by base name only, so e.g. a/song.wav
        # and b/song.flac write into the same folder and overwrite each other.
        audio_name = os.path.splitext(os.path.basename(audio_file))[0]
        output_dir = os.path.join(output_path, audio_name)
        os.makedirs(output_dir, exist_ok=True)

        print("Saving audio...")
        _save_stems(out, output_dir, num_stems, target_sr, args.output_format)


if __name__ == "__main__":
    main()