Spaces:
Runtime error
Runtime error
# flake8: noqa: E402
import gc
import os
import logging
import re_matching

# Silence chatty third-party loggers before configuring our own format.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
import torch
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr

# import webbrowser
import numpy as np
from config import config

# multithreading: let torch use every available CPU core for both
# intra-op and inter-op parallelism.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())
net_g = None  # generator network; assigned in __main__ via get_net_g()
device = config.device
if device == "mps":
    # Apple MPS backend: fall back to CPU for ops MPS does not implement.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def free_up_memory():
    """Reclaim memory left behind by a previous inference run.

    An earlier run may have died with an exception while holding large
    tensors; collect garbage and, when CUDA is present, drop the
    allocator's cache so the upcoming run has maximum headroom.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    # language,
    # reference_audio,
    # emotion,
    style_text,
    style_weight,
    skip_start=False,
    skip_end=False,
):
    """Run inference on each text slice and return a list of 16-bit clips.

    NOTE(review): the ``skip_start``/``skip_end`` parameters are kept for
    interface compatibility but are recomputed for every slice below —
    only the first slice keeps its start and only the last keeps its end.
    """
    free_up_memory()

    rendered = []
    last = len(slices) - 1
    with torch.no_grad():
        for pos, segment in enumerate(slices):
            raw = infer(
                segment,
                # reference_audio=reference_audio,
                emotion=None,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language="ZH",
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=pos != 0,
                skip_end=pos != last,
                style_text=style_text,
                style_weight=style_weight,
            )
            # Gradio audio components expect 16-bit PCM samples.
            rendered.append(gr.processing_utils.convert_to_16_bit_wav(raw))
    return rendered
def process_text(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    # language,
    # reference_audio,
    # emotion,
    style_text=None,
    style_weight=0,
):
    """Split *text* on '|' and synthesize each piece.

    Returns a fresh list of 16-bit audio clips, one per slice.
    """
    segments = text.split("|")
    return list(
        generate_audio(
            segments,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            speaker,
            # language,
            # reference_audio,
            # emotion,
            style_text,
            style_weight,
        )
    )
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    # reference_audio,
    # emotion,
    # prompt_mode,
    style_text=None,
    style_weight=0,
):
    """Gradio click handler: synthesize *text* for *speaker*.

    Returns a ``(status_message, (sampling_rate, samples))`` pair as
    expected by the Textbox/Audio output components.
    """
    # An empty auxiliary text means the style feature is disabled.
    if style_text == "":
        style_text = None
    # if prompt_mode == "Audio prompt":
    #     if reference_audio == None:
    #         return ("Invalid audio prompt", None)
    #     else:
    #         reference_audio = load_audio(reference_audio)[1]
    # else:
    #     reference_audio = None
    clips = process_text(
        text,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        # language,
        # reference_audio,
        # emotion,
        style_text,
        style_weight,
    )
    combined = np.concatenate(clips)
    return "Success", (hps.data.sampling_rate, combined)
if __name__ == "__main__":
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        logging.basicConfig(level=logging.DEBUG)
    # Load model hyperparameters from the configured config.json.
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # If config.json does not declare a version, default to the latest one.
    version = hps.version if hasattr(hps, "version") else latest_version
    # Instantiate the generator network used by infer() above.
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "EN", "mix", "auto"]
    # Build the Gradio UI: left column holds inputs, right column outputs.
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="输入文本内容",
                )
                # trans = gr.Button("中翻日", variant="primary")
                # slicer = gr.Button("快速切分", variant="primary")
                # formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                # _ = gr.Markdown(
                #     value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
                #     visible=False,
                # )
                # prompt_mode = gr.Radio(
                #     ["Text prompt", "Audio prompt"],
                #     label="Prompt Mode",
                #     value="Text prompt",
                #     visible=False,
                # )
                # text_prompt = gr.Textbox(
                #     label="Text prompt",
                #     placeholder="用文字描述生成风格。如:Happy",
                #     value="Happy",
                #     visible=False,
                # )
                # audio_prompt = gr.Audio(
                #     label="Audio prompt", type="filepath", visible=False
                # )
                # Synthesis knobs forwarded verbatim to infer().
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
                )
                btn = gr.Button("生成音频!", variant="primary")
            with gr.Column():
                # Optional auxiliary ("style") text to steer emotion.
                with gr.Accordion("融合文本语义", open=False):
                    gr.Markdown(
                        value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
                        "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
                        "效果较不明确,留空即为不使用该功能"
                    )
                    style_text = gr.Textbox(label="辅助文本")
                    style_weight = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.7,
                        step=0.1,
                        label="Weight",
                        info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
                    )
                text_output = gr.Textbox(label="状态信息")
                audio_output = gr.Audio(label="输出音频")
                # explain_image = gr.Image(
                #     label="参数解释信息",
                #     show_label=True,
                #     show_share_button=False,
                #     show_download_button=False,
                #     value=os.path.abspath("./img/参数说明.png"),
                # )
        # Wire the generate button to the synthesis callback.
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                # language,
                # audio_prompt,
                # text_prompt,
                # prompt_mode,
                style_text,
                style_weight,
            ],
            outputs=[text_output, audio_output],
        )
        # trans.click(
        #     translate,
        #     inputs=[text],
        #     outputs=[text],
        # )
        # slicer.click(
        #     tts_split,
        #     inputs=[
        #         text,
        #         speaker,
        #         sdp_ratio,
        #         noise_scale,
        #         noise_scale_w,
        #         length_scale,
        #         language,
        #         opt_cut_by_sent,
        #         interval_between_para,
        #         interval_between_sent,
        #         # audio_prompt,
        #         # text_prompt,
        #         style_text,
        #         style_weight,
        #     ],
        #     outputs=[text_output, audio_output],
        # )
        # formatter.click(
        #     format_utils,
        #     inputs=[text, speaker],
        #     outputs=[language, text],
        # )
    print("推理页面已开启!")
    # webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)