# so / app.py
# Author: leesenx (commit 91bc463, "Update app.py", verified)
#!/usr/bin/env python3
import html
import os
import re
import tempfile
import time
import uuid
from datetime import datetime
from pathlib import Path

import gradio as gr
import sherpa_onnx
import soundfile as sf

from model import get_pretrained_model, get_speaker_map, language_to_models
def MyPrint(s):
    """Print ``s`` prefixed by the current local time (microsecond resolution)."""
    print(f"{datetime.now():%Y-%m-%d %H:%M:%S.%f}: {s}")
def get_num_speakers(repo_id: str) -> int:
    """Return the speaker count encoded in ``repo_id``.

    The count is read from a ``"|<N> speaker..."`` suffix. When no such suffix
    is present the model is treated as single-speaker and 1 is returned.
    """
    found = re.search(r"\|(\d+)\s*speaker", repo_id)
    return 1 if found is None else int(found.group(1))
def get_speaker_choices(repo_id: str) -> list:
    """Build the speaker dropdown labels for a model.

    If ``get_speaker_map`` returns a non-empty mapping, the labels are its
    values (formatted as strings), ordered by sorting the ``(key, value)``
    pairs. Otherwise generic labels ``"说话人 0" .. "说话人 N-1"`` are produced,
    with N taken from ``get_num_speakers``.
    """
    speaker_map = get_speaker_map(repo_id)
    if not speaker_map:
        return [f"说话人 {idx}" for idx in range(get_num_speakers(repo_id))]
    return [format(name) for _, name in sorted(speaker_map.items())]
def extract_sid_from_label(repo_id: str, label: str) -> int:
    """Map a speaker dropdown label back to a speaker id.

    Resolution order:
      1. If ``get_speaker_map`` returns a non-empty mapping, return the key whose
         value, formatted exactly as ``get_speaker_choices`` formats it
         (``f"{v}"``), equals ``label``.
      2. Otherwise parse a generic ``"说话人 <N>"`` label and return N.
      3. Fall back to 0, including for an empty or None label.

    Args:
        repo_id: Model identifier forwarded to ``get_speaker_map``.
        label: Current speaker dropdown value. The dropdown accepts custom
            values, so this may be arbitrary user text or empty.

    Returns:
        The speaker id. For mapped speakers this is the mapping key as stored
        (presumably an int — TODO confirm against ``model.get_speaker_map``).
    """
    # Guard before re.search, which raises TypeError on None.
    if not label:
        return 0
    speaker_map = get_speaker_map(repo_id)
    if speaker_map:
        for sid, name in speaker_map.items():
            # The dropdown choices are built with f"{v}"; compare the same
            # rendering so non-str map values can still be matched.
            if f"{name}" == label:
                return sid
    m = re.search(r"说话人\s*(\d+)", label)
    return int(m.group(1)) if m else 0
# Page heading, rendered with gr.Markdown.
title = "# 文字转语音 (TTS)"

# Stylesheet for the result box produced by build_html_output: a flex-column
# container holding rounded items, with a success and an error colour variant.
css = "\n" + "\n".join(
    [
        ".result {display:flex;flex-direction:column}",
        ".result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}",
        ".result_item_success {background-color:mediumaquamarine;color:white;align-self:start}",
        ".result_item_error {background-color:#ff7070;color:white;align-self:start}",
    ]
) + "\n"
def update_model_dropdown(language: str):
    """Refresh the model and speaker dropdowns after a language change.

    Returns a pair of ``gr.Dropdown`` updates:
      * the model dropdown, listing the language's models with the first one
        selected;
      * the speaker dropdown for that first model (see update_speaker_dropdown).

    Raises:
        ValueError: if ``language`` is not a key of ``language_to_models``.
    """
    if language not in language_to_models:
        raise ValueError(f"不支持的语言: {language}")

    models = language_to_models[language]
    default_model = models[0]
    model_update = gr.Dropdown(
        choices=models,
        value=default_model,
        interactive=True,
    )
    return model_update, update_speaker_dropdown(default_model)
def update_speaker_dropdown(repo_id: str):
    """Return a speaker dropdown update for ``repo_id``.

    The first speaker is preselected. The dropdown is only shown when the
    model offers more than one speaker label.
    """
    labels = get_speaker_choices(repo_id)
    has_several = len(labels) > 1
    return gr.Dropdown(
        choices=labels,
        value=labels[0],
        visible=has_several,
        interactive=True,
    )
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap ``s`` in the result markup styled by the module's ``css``.

    ``s`` is inserted verbatim (it may already contain HTML such as
    ``<br/>``). ``style`` selects the CSS class of the inner item, e.g.
    ``"result_item_success"`` or ``"result_item_error"``.
    """
    lines = [
        "",
        "<div class='result'>",
        f"<div class='result_item {style}'>",
        f"{s}",
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(lines)
def process(language: str, repo_id: str, text: str, speaker: str, speed: float):
    """Synthesize ``text`` with the selected model and speaker.

    Args:
        language: Selected language. Not used here; the model is identified
            by ``repo_id``.
        repo_id: Model identifier from the model dropdown.
        text: Text to synthesize. Rejected when blank or longer than
            ``max_len`` characters.
        speaker: Speaker dropdown label. That dropdown accepts custom values,
            so this may be arbitrary user text.
        speed: Speaking-rate value forwarded to ``get_pretrained_model``.

    Returns:
        ``(wav_path, info_html)``. ``wav_path`` is None when the request is
        rejected; otherwise it is a 16-bit PCM wav written to the system
        temp directory.

    Raises:
        ValueError: if the model produced no audio samples.
    """
    max_len = 4000
    text = text or ""
    sid = extract_sid_from_label(repo_id, speaker)
    MyPrint(f"输入文本长度 {len(text)}: {text[:max_len]}. 说话人: {speaker}(id={sid}), 语速: {speed}")

    if not text.strip():
        # Nothing to synthesize; report it instead of calling the model.
        return None, build_html_output("请输入要转换为语音的文字", "result_item_error")

    if len(text) > max_len:
        MyPrint(f"文本过长!{len(text)}")
        info = "为保证响应速度,请使用短文本进行测试。如需处理长文本,请在本地运行。"
        return None, build_html_output(info, "result_item_error")

    n = get_num_speakers(repo_id)
    if n > 1 and sid >= n:
        # Clamp ids parsed from user-typed labels to the last valid speaker.
        sid = n - 1

    tts = get_pretrained_model(repo_id, speed)

    # perf_counter is monotonic, so the elapsed time is not skewed by clock changes.
    start = time.perf_counter()
    audio = tts.generate(text, sid=sid)
    elapsed_seconds = time.perf_counter() - start

    if len(audio.samples) == 0:
        raise ValueError("语音生成出错,请查看上方错误信息。")

    duration = len(audio.samples) / audio.sample_rate
    rtf = elapsed_seconds / duration
    # speaker can be user-typed; escape it before embedding it in HTML.
    info = (
        "\n"
        f"音频时长: {duration:.3f} 秒<br/>\n"
        f"处理时间: {elapsed_seconds:.3f} 秒<br/>\n"
        f"实时率(RTF): {rtf:.3f}<br/>\n"
        f"说话人: {html.escape(str(speaker))}\n"
    )
    MyPrint(info)
    MyPrint(f"\nrepo_id: {repo_id}\ntext: {text}\nsid: {sid}\nspeed: {speed}")

    # Write to the system temp dir instead of the working directory so
    # generated clips do not pile up next to the app's source files.
    filename = str(Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.wav")
    sf.write(
        filename,
        audio.samples,
        samplerate=audio.sample_rate,
        subtype="PCM_16",
    )
    return filename, build_html_output(info)
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)

    # Defaults shown on first load: first language, its first model, and
    # that model's speaker labels.
    language_choices = list(language_to_models.keys())
    first_model = language_to_models[language_choices[0]][0]
    first_speakers = get_speaker_choices(first_model)

    language_radio = gr.Radio(
        label="语言",
        choices=language_choices,
        value=language_choices[0],
    )
    model_dropdown = gr.Dropdown(
        choices=language_to_models[language_choices[0]],
        label="选择模型",
        value=first_model,
    )
    # Hidden for single-speaker models; accepts typed values.
    speaker_dropdown = gr.Dropdown(
        choices=first_speakers,
        value=first_speakers[0],
        label="选择说话人",
        visible=len(first_speakers) > 1,
        interactive=True,
        allow_custom_value=True,
    )

    # Changing the language refreshes both dropdowns; changing the model
    # refreshes the speaker list.
    language_radio.change(
        fn=update_model_dropdown,
        inputs=language_radio,
        outputs=[model_dropdown, speaker_dropdown],
    )
    model_dropdown.change(
        fn=update_speaker_dropdown,
        inputs=model_dropdown,
        outputs=speaker_dropdown,
    )

    with gr.Tabs():
        with gr.TabItem("输入文本"):
            input_text = gr.Textbox(
                label="输入文本",
                info="请输入要转换为语音的文字",
                lines=3,
                value="大家好,这是一个文字转语音的测试。",
                placeholder="请输入要转换为语音的文字",
            )
            input_speed = gr.Slider(
                minimum=0.1,
                maximum=10,
                value=1,
                step=0.1,
                label="语速(越大越快,越小越慢)",
            )
            input_button = gr.Button("生成语音")
            output_audio = gr.Audio(label="生成的语音")
            output_info = gr.HTML(label="信息")

            input_button.click(
                fn=process,
                inputs=[
                    language_radio,
                    model_dropdown,
                    input_text,
                    speaker_dropdown,
                    input_speed,
                ],
                outputs=[output_audio, output_info],
            )
def download_espeak_ng_data():
    """Download and unpack runtime data archives into /tmp.

    Two archives are fetched:
      * ``espeak-ng-data.tar.bz2`` from the sherpa-onnx ``tts-models`` release;
      * ``dict.tar.bz2`` from the cppjieba ``sherpa-onnx-2024-04-19`` release,
        which produces ``/tmp/dict``.

    Each download is skipped when its target directory already exists, so a
    restart does not fetch the data again. The shell steps are chained with
    ``&&``, so a failed ``cd`` or download stops the sequence. A non-zero exit
    status is logged via MyPrint rather than raised, so the app still starts.
    """
    # NOTE(review): assumes the espeak archive unpacks to /tmp/espeak-ng-data — confirm.
    if not Path("/tmp/espeak-ng-data").is_dir():
        ret = os.system(
            "cd /tmp && "
            "wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 && "
            "tar xf espeak-ng-data.tar.bz2"
        )
        if ret != 0:
            MyPrint(f"下载 espeak-ng-data 失败 (返回码 {ret})")

    if not Path("/tmp/dict").is_dir():
        ret = os.system(
            "cd /tmp && "
            "curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2 && "
            "tar xvf dict.tar.bz2"
        )
        if ret != 0:
            MyPrint(f"下载 dict 失败 (返回码 {ret})")
        # Show what was unpacked, for the startup log.
        os.system("ls -lh /tmp/dict")
def main():
    """Fetch the runtime data files, then start the Gradio app."""
    download_espeak_ng_data()
    demo.launch()


if __name__ == "__main__":
    main()