# File: CosyVoice2/scripts/gradio_demo.py
# Repository page header (committer / commit message / commit hash):
#   lihongjie
#   replace mnn with onnxruntime
#   26bfbe8
import argparse
import shutil
import gradio as gr
import numpy as np
import requests
import time
import os
import torch
from frontend import CosyVoiceFrontEnd
import torchaudio
import logging
logging.basicConfig(level=logging.WARNING)
import subprocess
import re
def get_all_local_ips():
    """Return the non-loopback IPv4 addresses reported by ``ip a``.

    The result is only used to print convenience URLs, so a missing or
    unusable ``ip`` executable (for example on macOS, Windows or a minimal
    container image) yields an empty list instead of raising.

    Returns:
        list[str]: Dotted-quad IPv4 addresses in the order they appear in the
        command output, excluding every address that starts with ``127.``.
    """
    try:
        result = subprocess.run(['ip', 'a'], capture_output=True, text=True)
    except OSError:
        # FileNotFoundError / PermissionError: the `ip` tool cannot be run here.
        return []
    # Match every IPv4 address that follows an "inet " token.
    ips = re.findall(r'inet (\d+\.\d+\.\d+\.\d+)', result.stdout)
    # Filter out loopback addresses.
    return [ip for ip in ips if not ip.startswith('127.')]
# Endpoints of the TTS backend service this demo talks to. All of them share
# the same host and port, so they are derived from a single base address.
_TTS_SERVER = "http://0.0.0.0:12346"
TTS_URL = f"{_TTS_SERVER}/tts"                    # submit a synthesis request
GET_URL = f"{_TTS_SERVER}/get"                    # poll synthesis status
TIMESTEPS_URL = f"{_TTS_SERVER}/timesteps"        # update the timesteps setting
PROMPT_FILES_URL = f"{_TTS_SERVER}/prompt_files"  # register a prompt-file directory
# Command-line options. They are parsed at import time because `args` and
# `frontend` are module-level objects used by the callbacks below.
_parser = argparse.ArgumentParser()
_parser.add_argument(
    '--model_dir',
    type=str,
    default="scripts/CosyVoice-BlankEN",
    help="tokenizer configuration directionary",
)
_parser.add_argument(
    '--wetext_dir',
    type=str,
    default="pengzhendong/wetext",
    help="path to wetext",
)
_parser.add_argument(
    '--sample_rate',
    type=int,
    default=24000,
    help="Sampling rate for prompt audio",
)
args = _parser.parse_args()

# Front end used to turn a prompt text + prompt audio into model inputs.
# NOTE(review): the meaning of each positional argument is defined by
# CosyVoiceFrontEnd in the `frontend` module, which is not visible here;
# the ONNX paths are relative to the current working directory.
frontend = CosyVoiceFrontEnd(
    args.model_dir,
    args.wetext_dir,
    "frontend-onnx/campplus.onnx",
    "frontend-onnx/speech_tokenizer_v2.onnx",
    f"{args.model_dir}/spk2info.pt",
    "all",
)
def _load_wav(wav, target_sr):
    """Load an audio file as mono and resample it down to ``target_sr``.

    Args:
        wav: Path (or other source) accepted by ``torchaudio.load``.
        target_sr: Desired sample rate in Hz.

    Returns:
        The waveform tensor averaged over its channel dimension (the channel
        dimension is kept with size 1) at ``target_sr``.

    Raises:
        ValueError: If the file's sample rate is lower than ``target_sr``;
            up-sampling is deliberately not supported.
    """
    speech, sample_rate = torchaudio.load(wav, backend='soundfile')
    # Down-mix all channels to a single channel.
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        if sample_rate < target_sr:
            raise ValueError('wav sample rate {} must be greater than {}'.format(sample_rate, target_sr))
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech


def update_audio(audio_input_path, audio_text):
    """Build prompt features from a reference recording and register them.

    The prompt audio and its transcript are run through the module-level
    ``frontend``; every resulting tensor whose key does not contain ``_len``
    is flattened and written to ``./output_temp/<key>.txt`` (comma-delimited),
    and the directory is then announced to the backend via
    ``PROMPT_FILES_URL``.

    Args:
        audio_input_path: Path of the prompt audio file.
        audio_text: Transcript of the prompt audio.

    Returns:
        ``None`` when the backend answers with HTTP 200, otherwise a
        ``(None, message)`` tuple describing the request failure.

    Raises:
        ValueError: If no audio path is given, if the audio sample rate is
            below 16 kHz, or if fewer than 75 prompt speech tokens result.
    """
    # Validate before touching the output directory so that a click without
    # an uploaded file does not wipe the previously written prompt files.
    if not audio_input_path:
        raise ValueError("audio_input_path is empty: provide a prompt audio file first")

    output_dir = './output_temp'
    # Start from an empty output directory.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    zero_shot_spk_id = ""
    prompt_speech_16k = _load_wav(audio_input_path, 16000)
    prompt_text = audio_text
    print("prompt_text", prompt_text)
    model_input = frontend.process_prompt(prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)
    print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
    token_count = model_input["flow_prompt_speech_token"].shape[1]
    if token_count < 75:
        raise ValueError(f"speech_token length should >= 75, but got {token_count}")

    for k, v in model_input.items():
        # Length entries are not exported.
        if "_len" in k:
            continue
        flat = v.detach().cpu().numpy().reshape(-1)
        if v.dtype in (torch.int32, torch.int64):
            # Integer tensors are written without a fractional part.
            np.savetxt(f"{output_dir}/{k}.txt", flat, fmt="%d", delimiter=",")
        else:
            np.savetxt(f"{output_dir}/{k}.txt", flat, delimiter=",")

    try:
        r = requests.post(PROMPT_FILES_URL, json={"prompt_files": output_dir}, timeout=5)
        if r.status_code != 200:
            return None, "❌ TTS 请求失败"
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"
def update_timesteps(timesteps):
    """Send a new timesteps value to the backend.

    Args:
        timesteps: Value posted as the ``timesteps`` field of the JSON body.

    Returns:
        ``None`` when the backend answers with HTTP 200, otherwise a
        ``(None, message)`` tuple describing the failure.
    """
    try:
        response = requests.post(TIMESTEPS_URL, json={"timesteps": timesteps}, timeout=5)
    except Exception as exc:
        return None, f"❌ TTS 请求异常: {exc}"
    if response.status_code == 200:
        return None
    return None, "❌ TTS 请求失败"
def run_tts(text):
    """Submit ``text`` for synthesis and wait for the generated audio file.

    Args:
        text: Text posted to the backend as the ``text`` JSON field.

    Returns:
        ``(wav_path, status_message)`` on success, or ``(None, message)`` when
        a request fails or no existing audio file is reported within the
        polling budget (100 polls, 0.5 s apart).
    """
    # Step 1: submit the TTS request.
    try:
        submit = requests.post(TTS_URL, json={"text": text}, timeout=5)
    except Exception as exc:
        return None, f"❌ TTS 请求异常: {exc}"
    if submit.status_code != 200:
        return None, "❌ TTS 请求失败"

    # Step 2: poll /get until the backend reports that synthesis has finished.
    # NOTE(review): Gradio documents gr.Progress() as a default parameter
    # value; confirm that an instance created here is reported in the UI.
    progress = gr.Progress()
    max_polls = 100  # upper bound on polls so the loop cannot run forever
    wav_file = None
    for attempt in range(max_polls):
        time.sleep(0.5)
        try:
            state = requests.post(GET_URL, data="", timeout=5).json()
        except Exception as exc:
            return None, f"❌ GET 请求异常: {exc}"
        # A missing flag is treated as "still running".
        if not state.get("b_tts_runing", True):
            wav_file = state.get("wav_file")
            break
        progress(attempt / max_polls, desc="正在生成语音...")

    if wav_file and os.path.exists(wav_file):
        return wav_file, "✅ 生成完成"
    return None, "❌ 语音文件未生成"
# Gradio UI definition.
# NOTE(review): the nesting of the Row/Column containers is reconstructed from
# the statement order; confirm it against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("### 🎙️ AXERA CosyVoice2 Demo")

    # Prompt section: reference recording plus its transcript.
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="输入音频",
                type="filepath",
            )
        with gr.Column():
            audio_text = gr.Textbox(
                label="音频文本(自己改一下或者照着念)",
                value="锄禾日当午,汗滴禾下土。",
            )
            btn_update = gr.Button("更新音源")

    # Synthesis section: text to speak and the timesteps setting.
    with gr.Row():
        text_input = gr.Textbox(
            value="琦琦,麻烦你适配一下这个新的模型吧。",
            label="输入文本",
        )
        with gr.Column():
            timesteps = gr.Slider(
                minimum=4,
                maximum=30,
                value=10,
                step=1,
                label="Timesteps",
            )
            run_btn = gr.Button("生成语音")

    # Result section.
    status = gr.Label(label="状态")
    audio_out = gr.Audio(label="生成结果", type="filepath")

    # Event wiring. Only run_tts has output components; the other two
    # callbacks are bound without outputs.
    run_btn.click(
        fn=run_tts,
        inputs=[text_input],
        outputs=[audio_out, status],
    )
    timesteps.change(
        fn=update_timesteps,
        inputs=timesteps,
    )
    btn_update.click(
        fn=update_audio,
        inputs=[audio_input, audio_text],
    )
# Print one HTTPS URL per local non-loopback address so the demo can be
# opened from other machines on the network.
ips = get_all_local_ips()
if ips:
    print("\n".join(f"* Running on local URL: https://{ip}:7860" for ip in ips))

# Serve on all interfaces over HTTPS using the certificate/key pair expected
# in the current working directory.
launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "ssl_certfile": "./server.crt",
    "ssl_keyfile": "./server.key",
    "ssl_verify": False,
}
demo.launch(**launch_options)