import gradio as gr
import soundfile as sf


from decode import InferencePipeline
from datahandler import AudioMixer, fix_audio_format
from omegaconf import OmegaConf

# Load both pipelines once at startup so per-request latency stays low.
MODEL_CACHE = {
    "base_model": InferencePipeline(OmegaConf.load("config/config.yaml")),
    "iter_model": InferencePipeline(OmegaConf.load("config/config_ira.yaml")),
}
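
# If startup time or memory matters more than first-request latency, a lazy
# alternative is to build each pipeline on first use (a sketch, assuming
# InferencePipeline construction is the expensive step):
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=None)
#   def get_pipeline(name):
#       cfg = {"base_model": "config/config.yaml",
#              "iter_model": "config/config_ira.yaml"}[name]
#       return InferencePipeline(OmegaConf.load(cfg))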

# Mixer used to synthesize noisy mixtures from clean recordings.
datamix = AudioMixer()


def gradio_TSE(input_audio_path, enroll_audio_path1, audio_type, model_select):
    print(f"Model selected: {model_select}")
    print(f"User uploaded audio path: {input_audio_path}")
    print(f"User enroll audio path: {enroll_audio_path1}")

    inter = MODEL_CACHE[model_select]
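    # Defensive variant of the lookup above, if extra validation is wanted
    # (gr.Error surfaces the message in the Gradio UI):
    #   if model_select not in MODEL_CACHE:
    #       raise gr.Error(f"Unknown model type: {model_select}")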

    audio_info = sf.info(input_audio_path)
    print(f"Input sample rate: {audio_info.samplerate} Hz")

    eol_info = sf.info(enroll_audio_path1)
    print(f"Enroll sample rate: {eol_info.samplerate} Hz")

    # Re-encode both clips as 16 kHz WAV before inference.
    input_deal_wav = fix_audio_format(input_audio_path)
    input_deal_wav_path = "deal_input.wav"
    sf.write(input_deal_wav_path, input_deal_wav, 16000)

    enroll_wav1 = fix_audio_format(enroll_audio_path1)
    eol_wav1 = "eol1.wav"
    sf.write(eol_wav1, enroll_wav1, 16000)
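
    # Note: fixed filenames like "deal_input.wav" collide if two requests run
    # concurrently. A per-request variant using the standard library (a
    # sketch, not the project's code):
    #
    #   import tempfile
    #   with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
    #       input_deal_wav_path = f.name
    #   sf.write(input_deal_wav_path, input_deal_wav, 16000)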

    # A second enrollment clip (enroll_wav2) for side-by-side comparison was
    # prototyped here; it is disabled in this single-enrollment demo.

    if audio_type == "clean":
        # Synthesize a noisy mixture from the clean recording.
        noise_folder_test = "noises/"
        mix_path = datamix.mix_with_noise_folder(input_deal_wav_path, noise_folder_test)
        print(f"Converted clean -> mix: {mix_path}")
    else:
        # The input is already a mixture; use it directly.
        mix_path = input_deal_wav_path
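
    # Rough sketch of the SNR-based mixing a helper like mix_with_noise_folder
    # typically performs (an assumption; the real AudioMixer may differ):
    #
    #   snr_db = 5.0  # target signal-to-noise ratio
    #   scale = ((speech**2).sum() / ((noise**2).sum() * 10**(snr_db / 10))) ** 0.5
    #   mixture = speech + scale * noise[:len(speech)]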


    # Run target speaker extraction with the enrollment clip as reference.
    est_path1 = inter.run_inference(mix_path, eol_wav1)

    return mix_path, est_path1



with gr.Blocks() as demo:
    gr.Markdown("## Target Speaker Extraction Demo")
    gr.Markdown(
        "This demo accepts either a clean recording (which it mixes with noise) "
        "or an already-noisy mixture. Because the training data is limited, "
        "recordings made directly in a mobile browser with music or vocal "
        "background noise tend to give the best results."
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_audio = gr.Audio(label="Upload/record your audio", type="filepath")
        with gr.Column(scale=1):
            audio_type = gr.Radio(
                choices=["clean", "mix"],
                value="clean",
                label="Input audio type?"
            )
            model_select = gr.Radio(
                choices=["base_model", "iter_model"],
                value="iter_model",
                label="Select Model Type"
            )
    with gr.Row():
        enroll_audio1 = gr.Audio(label="Upload your enrollment audio", type="filepath")

    with gr.Row():
        noisy_audio_output = gr.Audio(label="Noisy audio (processed input)", type="filepath")

    with gr.Row():
        extracted_audio_output1 = gr.Audio(label="Extracted target speaker audio", type="filepath")

    convert_button = gr.Button("Extract")
    convert_button.click(
        fn=gradio_TSE,
        inputs=[input_audio, enroll_audio1, audio_type, model_select],
        outputs=[noisy_audio_output, extracted_audio_output1]
    )

if __name__ == "__main__":
    demo.launch()
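    # Deployment notes (standard Gradio APIs; values are illustrative):
    #   demo.queue()  # request queue avoids timeouts for long-running inference
    #   demo.launch(server_name="0.0.0.0", server_port=7860, share=True)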