import gradio as gr
import os
import soundfile as sf
import numpy as np
from decode import InferencePipeline
from datahandler import AudioMixer, fix_audio_format
from omegaconf import OmegaConf
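
# Cache both inference pipelines at import time so model weights are loaded once,
# not on every request.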
MODEL_CACHE = {
    "base_model": InferencePipeline(OmegaConf.load("config/config.yaml")),
    "iter_model": InferencePipeline(OmegaConf.load("config/config_ira.yaml")),
}
# cfg = OmegaConf.load("config/config_ira.yaml")
# inter = InferencePipeline(cfg)
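
# AudioMixer synthesizes a noisy mixture from clean input (used for the "clean" path below).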
datamix = AudioMixer()


def gradio_TSE(input_audio_path, enroll_audio_path1, audio_type, model_select):
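    """Extract the target speaker from `input_audio_path`, conditioned on the
    enrollment clip `enroll_audio_path1`. Returns (mixture_path, extracted_path)."""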
| print(f"模型选择: {model_select}") | |
| print(f"User uploaded audio path: {input_audio_path}") | |
| print(f"User enroll audio path: {enroll_audio_path1}") | |
| # print(f"User enroll audio path: {enroll_audio_path2}") | |
    # if model_select == "base_model":
    #     cfg_path = "config/config_base.yaml"
    # elif model_select == "iter_model":
    #     cfg_path = "config/config_iter.yaml"
    # else:
    #     raise ValueError("Unknown model type")
    inter = MODEL_CACHE[model_select]
    audio_info = sf.info(input_audio_path)
    print(f"Input sample rate: {audio_info.samplerate} Hz")
    eol_info = sf.info(enroll_audio_path1)
    print(f"Enroll sample rate: {eol_info.samplerate} Hz")
    # eol_info = sf.info(enroll_audio_path2)
    # print(f"Enroll sample rate: {eol_info.samplerate} Hz")
    input_deal_wav = fix_audio_format(input_audio_path)
    input_deal_wav_path = "deal_input.wav"
    sf.write(input_deal_wav_path, input_deal_wav, 16000)

    enroll_wav1 = fix_audio_format(enroll_audio_path1)
    eol_wav1 = "eol1.wav"
    sf.write(eol_wav1, enroll_wav1, 16000)
    # if len(enroll_wav1) > 4 * 16000:
    #     middle_start = (len(enroll_wav1) - 3 * 16000) // 2
    #     middle_end = middle_start + 3 * 16000
    #     enroll_wav2 = enroll_wav1[middle_start:middle_end]
    #     print("Extracted enroll_wav2, length:", len(enroll_wav2))
    #     eol_wav2 = "eol2.wav"
    #     sf.write(eol_wav2, enroll_wav2, 16000)
    # else:
    #     enroll_wav2 = fix_audio_format(enroll_audio_path2)
    #     print("Loaded enroll_wav2, length:", len(enroll_wav2))
    #     eol_wav2 = "eol2.wav"
    #     sf.write(eol_wav2, enroll_wav2, 16000)
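
    # "Clean" input is mixed with noise from the local noises/ folder so there is an
    # interfering signal to remove; "mix" input is used as-is.
    # (mix_with_noise_folder is assumed to write the mixture to disk and return its path.)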
    if audio_type == "clean":
        noise_folder_test = "noises/"
        mix_path = datamix.mix_with_noise_folder(input_deal_wav_path, noise_folder_test)
        print(f"Converted clean -> mix: {mix_path}")
    else:
        mix_path = input_deal_wav_path
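
    # Extract the target speaker from the mixture, conditioned on the enrollment clip.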
    est_path1 = inter.run_inference(mix_path, eol_wav1)
    # est_path2 = inter.run_inference(mix_path, eol_wav2)
    return mix_path, est_path1


with gr.Blocks() as demo:
    gr.Markdown("## Target Speaker Extraction Demo")
    gr.Markdown(
        "This demo accepts either clean audio (which will be mixed with noise) or an already-mixed recording. "
        "Due to limited training data, recording directly in a mobile browser with music or vocal background noise "
        "is more likely to give good results."
    )
    with gr.Row():
        with gr.Column(scale=2):
            input_audio = gr.Audio(label="Upload/record your audio", type="filepath")
        with gr.Column(scale=1):
            audio_type = gr.Radio(
                choices=["clean", "mix"],
                value="clean",
                label="Input audio type?"
            )
            model_select = gr.Radio(
                choices=["base_model", "iter_model"],
                value="iter_model",
                label="Select Model Type"
            )
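
    # The enrollment recording provides the reference voice of the speaker to extract.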
    with gr.Row():
        enroll_audio1 = gr.Audio(label="Upload your first enroll audio", type="filepath")
        # enroll_audio2 = gr.Audio(label="Upload your second enroll audio to compare", type="filepath")

    with gr.Row():
        noisy_audio_output = gr.Audio(label="Noisy Audio (processed input audio)", type="filepath")

    with gr.Row():
        extracted_audio_output1 = gr.Audio(label="First enroll extracted target speaker audio", type="filepath")
        # extracted_audio_output2 = gr.Audio(label="Second enroll extracted target speaker audio", type="filepath")
    convert_button = gr.Button("Extract")
    convert_button.click(
        fn=gradio_TSE,
        inputs=[input_audio, enroll_audio1, audio_type, model_select],
        outputs=[noisy_audio_output, extracted_audio_output1]
    )
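

# Start the Gradio server when this file is run directly.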
if __name__ == "__main__":
    demo.launch()