# swc2's picture
# Update app.py
# 627ec8a verified
import gradio as gr
import os
import soundfile as sf
import numpy as np
from decode import InferencePipeline
from datahandler import AudioMixer, fix_audio_format
from omegaconf import OmegaConf
# Pre-load both inference pipelines once at import time so every Gradio
# request reuses an already-initialised model instead of reloading weights.
# Keys match the values offered by the "Select Model Type" radio button.
MODEL_CACHE = {
    "base_model": InferencePipeline(OmegaConf.load("config/config.yaml")),
    "iter_model": InferencePipeline(OmegaConf.load("config/config_ira.yaml")),
}

# Shared helper used to mix clean uploads with noise samples.
datamix = AudioMixer()
def gradio_TSE(input_audio_path, enroll_audio_path1, audio_type, model_select):
    """Run target-speaker extraction on one uploaded utterance.

    Parameters
    ----------
    input_audio_path : str
        Path to the uploaded audio (Gradio ``filepath`` component).
    enroll_audio_path1 : str
        Path to the enrollment audio of the target speaker.
    audio_type : str
        ``"clean"`` — the input is mixed with a noise file first;
        anything else — the input is used as the mixture directly.
    model_select : str
        Key into ``MODEL_CACHE`` (``"base_model"`` or ``"iter_model"``).

    Returns
    -------
    tuple[str, str]
        ``(mix_path, est_path1)`` — the (possibly noise-mixed) input and
        the extracted target-speaker audio, both as file paths.

    Raises
    ------
    KeyError
        If ``model_select`` is not a key of ``MODEL_CACHE``.
    """
    print(f"模型选择: {model_select}")
    print(f"User uploaded audio path: {input_audio_path}")
    print(f"User enroll audio path: {enroll_audio_path1}")
    inter = MODEL_CACHE[model_select]
    # Log the native sample rates before normalisation.
    audio_info = sf.info(input_audio_path)
    print(f"采样率: {audio_info.samplerate} Hz")
    eol_info = sf.info(enroll_audio_path1)
    print(f"采样率: {eol_info.samplerate} Hz")
    # Normalise both files to the 16 kHz format the pipelines expect.
    # NOTE(review): fixed temp filenames in the CWD will collide if two
    # requests run concurrently — consider tempfile.NamedTemporaryFile.
    input_deal_wav = fix_audio_format(input_audio_path)
    input_deal_wav_path = "deal_input.wav"
    sf.write(input_deal_wav_path, input_deal_wav, 16000)
    enroll_wav1 = fix_audio_format(enroll_audio_path1)
    eol_wav1 = "eol1.wav"
    sf.write(eol_wav1, enroll_wav1, 16000)
    if audio_type == "clean":
        # Clean input: synthesise a noisy mixture from the bundled noises.
        noise_folder_test = "noises/"
        mix_path = datamix.mix_with_noise_folder(input_deal_wav_path, noise_folder_test)
        print(f"Converted clean -> mix: {mix_path}")
    else:
        mix_path = input_deal_wav_path
    est_path1 = inter.run_inference(mix_path, eol_wav1)
    return mix_path, est_path1
with gr.Blocks() as demo:
    # --- Header ----------------------------------------------------------
    gr.Markdown("## Target Speaker Extraction Demo")
    gr.Markdown(
        "This demo can handle either clean audio (which we'll turn into a mix) or directly a mix audio. Due to limited training data, recording directly with a mobile browser with music/vocal noises is more likely to achieve good results"
    )

    # --- Inputs: mixture audio plus processing options -------------------
    with gr.Row():
        with gr.Column(scale=2):
            mixture_input = gr.Audio(label="Upload/record your audio", type="filepath")
        with gr.Column(scale=1):
            input_kind = gr.Radio(
                label="Input audio type?",
                choices=["clean", "mix"],
                value="clean",
            )
            model_choice = gr.Radio(
                label="Select Model Type",
                choices=["base_model", "iter_model"],
                value="iter_model",
            )

    # --- Enrollment audio of the target speaker --------------------------
    with gr.Row():
        enrollment_input = gr.Audio(label="Upload your first enroll audio", type="filepath")

    # --- Outputs ---------------------------------------------------------
    with gr.Row():
        processed_mix_out = gr.Audio(label="Noisy Audio (Processed input audio)", type="filepath")
    with gr.Row():
        extracted_out = gr.Audio(label="First enroll extracted target speaker audio", type="filepath")

    # Wire the button to the extraction callback.
    run_button = gr.Button("Extract")
    run_button.click(
        fn=gradio_TSE,
        inputs=[mixture_input, enrollment_input, input_kind, model_choice],
        outputs=[processed_mix_out, extracted_out],
    )

if __name__ == "__main__":
    demo.launch()