Spaces:
Runtime error
Runtime error
File size: 9,762 Bytes
a3c7d09 f175c06 a3c7d09 4d57ae2 a3c7d09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
# flake8: noqa: E402
import gc
import os
import logging
import re_matching
# Quieten chatty third-party libraries: only their warnings and errors get through.
for _noisy_logger in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Root logging configuration for this app: INFO and above, "| name | level | msg".
logging.basicConfig(
    level=logging.INFO,
    format="| %(name)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)
import torch
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr
# import webbrowser
import numpy as np
from config import config
# Multithreading: one PyTorch intra-op and inter-op thread per logical CPU.
# os.cpu_count() returns None when the count cannot be determined, which the
# torch thread setters do not accept, so fall back to a single thread.
_cpu_count = os.cpu_count() or 1
torch.set_num_threads(_cpu_count)
torch.set_num_interop_threads(_cpu_count)

# Model handle; assigned by get_net_g() in the __main__ section and read by
# generate_audio().
net_g = None
device = config.device
if device == "mps":
    # Ask PyTorch to run operations the MPS backend does not support on the CPU.
    # NOTE(review): some PyTorch versions only read this variable when torch is
    # first imported; torch is already imported above, so confirm the fallback
    # actually takes effect.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def free_up_memory():
    """Reclaim memory before starting a new inference run.

    A previous run that ended with an exception may have left large objects
    unreferenced but not yet collected. This forces a garbage-collection pass
    and, on CUDA machines, also asks PyTorch to release its cached GPU memory.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    style_text,
    style_weight,
    skip_start=False,
    skip_end=False,
):
    """Synthesize every text slice and return one 16-bit waveform per slice.

    Language is fixed to "ZH" and emotion to None. The model, hyper-parameters
    and device come from the module-level ``net_g``, ``hps`` and ``device``
    globals (``hps`` and ``net_g`` are only assigned when this file runs as a
    script).

    Args:
        slices: sized sequence of text pieces, synthesized in order.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: forwarded
            unchanged to ``infer``.
        speaker: speaker name, forwarded to ``infer`` as ``sid``.
        style_text: optional auxiliary text forwarded to ``infer``.
        style_weight: weight of ``style_text``, forwarded to ``infer``.
        skip_start: if True, the first slice is also synthesized with
            ``skip_start=True``; every later slice always is.
        skip_end: if True, the last slice is also synthesized with
            ``skip_end=True``; every earlier slice always is.

    Returns:
        List of waveforms converted by
        ``gr.processing_utils.convert_to_16_bit_wav``, in slice order
        (empty if ``slices`` is empty).
    """
    audio_list = []
    last_idx = len(slices) - 1
    # A previous run may have died mid-way; reclaim its memory first.
    free_up_memory()
    with torch.no_grad():
        for idx, piece in enumerate(slices):
            # Interior boundaries between slices always set the skip flags; the
            # outer boundaries honour the caller's arguments. (Previously the
            # loop overwrote skip_start/skip_end, so callers' values were
            # ignored.) What exactly infer() skips is defined there -
            # presumably boundary padding; confirm in infer.
            audio = infer(
                piece,
                emotion=None,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language="ZH",
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start or idx != 0,
                skip_end=skip_end or idx != last_idx,
                style_text=style_text,
                style_weight=style_weight,
            )
            audio_list.append(gr.processing_utils.convert_to_16_bit_wav(audio))
    return audio_list
def process_text(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    style_text=None,
    style_weight=0,
):
    """Split ``text`` on "|" and synthesize each resulting piece.

    The pieces are handed to ``generate_audio`` in order; the returned value is
    a fresh list holding the per-piece audio chunks it produced.
    """
    pieces = text.split("|")
    chunks = generate_audio(
        pieces,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        speaker,
        style_text,
        style_weight,
    )
    return list(chunks)
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    style_text=None,
    style_weight=0,
):
    """Gradio click handler: synthesize ``text`` and return status and audio.

    Args:
        text: input text; pieces separated by "|" are synthesized separately
            (by ``process_text``) and concatenated.
        speaker: speaker name chosen in the dropdown.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: synthesis
            parameters forwarded unchanged.
        style_text: optional auxiliary text. None, empty, or whitespace-only
            text means "not used" and is passed on as None.
        style_weight: mixing weight for ``style_text``.

    Returns:
        ``("Success", (sampling_rate, audio))`` where ``audio`` is the
        concatenation of all per-piece waveforms and ``sampling_rate`` comes
        from the module-level ``hps``.
    """
    # The UI says leaving the auxiliary text blank disables the feature; a
    # Textbox that only contains spaces/newlines counts as blank too.
    if style_text is not None and not style_text.strip():
        style_text = None
    audio_list = process_text(
        text,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        style_text,
        style_weight,
    )
    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)
if __name__ == "__main__":
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() does nothing once the root logger has handlers,
        # and this module already calls basicConfig() at import time, so a
        # second basicConfig(level=DEBUG) would be silently ignored. Lower the
        # root logger's level directly instead (keeps the existing format).
        logging.getLogger().setLevel(logging.DEBUG)

    # hps and net_g are module globals read by generate_audio() and tts_fn().
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Fall back to the latest model version when config.json does not set one.
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())

    # NOTE: the upstream translate / quick-split / MIX-format buttons and the
    # audio/text prompt-mode controls are not wired up in this build (their
    # handlers are not defined in this file), so they are omitted here.
    with gr.Blocks() as app:
        with gr.Row():
            # Left column: main text, speaker and synthesis parameters.
            with gr.Column():
                text = gr.TextArea(
                    label="输入文本内容",
                )
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
                )
                btn = gr.Button("生成音频!", variant="primary")
            # Right column: optional auxiliary (style) text and the outputs.
            with gr.Column():
                with gr.Accordion("融合文本语义", open=False):
                    gr.Markdown(
                        value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
                        "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
                        "效果较不明确,留空即为不使用该功能"
                    )
                    style_text = gr.Textbox(label="辅助文本")
                    style_weight = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.7,
                        step=0.1,
                        label="Weight",
                        info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
                    )
                text_output = gr.Textbox(label="状态信息")
                audio_output = gr.Audio(label="输出音频")

        # Input order must match tts_fn's positional parameters.
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                style_text,
                style_weight,
            ],
            outputs=[text_output, audio_output],
        )

    print("推理页面已开启!")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
|