# flake8: noqa: E402
import gc
import os
import logging
import re_matching
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
import torch
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr
# import webbrowser
import numpy as np
from config import config
# multithreading
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())
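# set_num_threads sets intra-op parallelism (threads used inside a single op,
# e.g. a large matmul); set_num_interop_threads sets inter-op parallelism
# across independent ops. The interop value can only be set before the first
# parallel region runs, which is why both are configured here at import time.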
net_g = None
device = config.device
if device == "mps":
    # Fall back to CPU for ops not yet implemented on Apple's MPS backend.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def free_up_memory():
    # A prior inference run may have left large tensors alive if it raised an
    # exception partway through. Free as much memory as possible so this run
    # can succeed.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    # language,
    # reference_audio,
    # emotion,
    style_text,
    style_weight,
    skip_start=False,
    skip_end=False,
):
    audio_list = []
    # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
    free_up_memory()
    with torch.no_grad():
        for idx, piece in enumerate(slices):
            # Skip the synthesized lead-in/lead-out for every piece except the
            # first/last, so the concatenated slices join without gaps.
            skip_start = idx != 0
            skip_end = idx != len(slices) - 1
            audio = infer(
                piece,
                # reference_audio=reference_audio,
                emotion=None,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language="ZH",
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
                style_text=style_text,
                style_weight=style_weight,
            )
            # Convert to 16-bit PCM before collecting, so np.concatenate in
            # tts_fn joins arrays of a uniform dtype.
            audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
            audio_list.append(audio16bit)
    return audio_list
def process_text(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    # language,
    # reference_audio,
    # emotion,
    style_text=None,
    style_weight=0,
):
    audio_list = []
    audio_list.extend(
        generate_audio(
            text.split("|"),  # "|" marks manual slice boundaries in the input
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            speaker,
            # language,
            # reference_audio,
            # emotion,
            style_text,
            style_weight,
        )
    )
    return audio_list
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    # reference_audio,
    # emotion,
    # prompt_mode,
    style_text=None,
    style_weight=0,
):
    # An empty style text disables the style-mixing feature entirely.
    if style_text == "":
        style_text = None
    # if prompt_mode == "Audio prompt":
    #     if reference_audio == None:
    #         return ("Invalid audio prompt", None)
    #     else:
    #         reference_audio = load_audio(reference_audio)[1]
    # else:
    #     reference_audio = None
    audio_list = process_text(
        text,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        # language,
        # reference_audio,
        # emotion,
        style_text,
        style_weight,
    )
    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)
if __name__ == "__main__":
    if config.webui_config.debug:
        logger.info("Enabling DEBUG-level logging")
        # basicConfig is a no-op once handlers exist (it was already called at
        # import time), so adjust the root logger level directly.
        logging.getLogger().setLevel(logging.DEBUG)
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Default to the latest version when config.json does not specify one
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    # Currently unused: language is hardcoded to "ZH" in generate_audio
    languages = ["ZH", "JP", "EN", "mix", "auto"]
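    # spk2id maps display names to the integer speaker ids baked into the
    # checkpoint; tts_fn passes the selected name through to infer(), which is
    # assumed to resolve it back to an id via this same mapping.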
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="Input text",
                )
                # trans = gr.Button("Translate ZH to JP", variant="primary")
                # slicer = gr.Button("Quick split", variant="primary")
                # formatter = gr.Button("Detect language and format as MIX", variant="primary")
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                # _ = gr.Markdown(
                #     value="Prompt mode: choose a text prompt or an audio prompt to steer the style of the generated voice.\n",
                #     visible=False,
                # )
                # prompt_mode = gr.Radio(
                #     ["Text prompt", "Audio prompt"],
                #     label="Prompt Mode",
                #     value="Text prompt",
                #     visible=False,
                # )
                # text_prompt = gr.Textbox(
                #     label="Text prompt",
                #     placeholder="Describe the desired style in words, e.g. Happy",
                #     value="Happy",
                #     visible=False,
                # )
                # audio_prompt = gr.Audio(
                #     label="Audio prompt", type="filepath", visible=False
                # )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
                )
                btn = gr.Button("Generate audio!", variant="primary")
            with gr.Column():
                with gr.Accordion("Fuse auxiliary text semantics", open=False):
                    gr.Markdown(
                        value="Use the semantics of an auxiliary text to help guide generation (keep its language the same as the main text).\n\n"
                        "**Note**: do not use **instruction-style text** (e.g. 'happy'); use **text with strong emotion** (e.g. 'I am so happy!!!').\n\n"
                        "The effect is rather subtle; leave blank to disable this feature."
                    )
                    style_text = gr.Textbox(label="Auxiliary text")
                    style_weight = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.7,
                        step=0.1,
                        label="Weight",
                        info="BERT mixing ratio between the main and auxiliary text: 0 = main text only, 1 = auxiliary text only",
                    )
                text_output = gr.Textbox(label="Status")
                audio_output = gr.Audio(label="Output audio")
                # explain_image = gr.Image(
                #     label="Parameter explanation",
                #     show_label=True,
                #     show_share_button=False,
                #     show_download_button=False,
                #     value=os.path.abspath("./img/参数说明.png"),
                # )
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                # language,
                # audio_prompt,
                # text_prompt,
                # prompt_mode,
                style_text,
                style_weight,
            ],
            outputs=[text_output, audio_output],
        )
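        # tts_fn returns ("Success", (sampling_rate, waveform)); Gradio routes
        # the string to text_output and the (rate, array) tuple to audio_output.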
        # trans.click(
        #     translate,
        #     inputs=[text],
        #     outputs=[text],
        # )
        # slicer.click(
        #     tts_split,
        #     inputs=[
        #         text,
        #         speaker,
        #         sdp_ratio,
        #         noise_scale,
        #         noise_scale_w,
        #         length_scale,
        #         language,
        #         opt_cut_by_sent,
        #         interval_between_para,
        #         interval_between_sent,
        #         # audio_prompt,
        #         # text_prompt,
        #         style_text,
        #         style_weight,
        #     ],
        #     outputs=[text_output, audio_output],
        # )
        # formatter.click(
        #     format_utils,
        #     inputs=[text, speaker],
        #     outputs=[language, text],
        # )
print("推理页面已开启!")
# webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
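    # share=True asks Gradio to create a temporary public *.gradio.live URL in
    # addition to the local server; server_port pins the local port instead of
    # letting Gradio pick a free one. Both values come from config.webui_config.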
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)