lihongjie
committed on
Commit
·
41b3743
1
Parent(s):
f520907
add gradio gui code
Browse files- .gitattributes +2 -0
- .gitignore +3 -1
- main_api_ax650 +3 -0
- main_api_axcl_aarch64 +3 -0
- run_api_ax650.sh +23 -0
- run_api_axcl_aarch64.sh +24 -0
- scripts/gradio_demo.py +161 -0
- scripts/requirements.txt +1 -0
.gitattributes
CHANGED
|
@@ -111,3 +111,5 @@ token2wav-axmodels/rand_noise_1_80_300.txt filter=lfs diff=lfs merge=lfs -text
|
|
| 111 |
token2wav-axmodels/speech_window_2x8x480.txt filter=lfs diff=lfs merge=lfs -text
|
| 112 |
token2wav-axmodels/hift_p1_50_first.mnn filter=lfs diff=lfs merge=lfs -text
|
| 113 |
main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 111 |
token2wav-axmodels/speech_window_2x8x480.txt filter=lfs diff=lfs merge=lfs -text
|
| 112 |
token2wav-axmodels/hift_p1_50_first.mnn filter=lfs diff=lfs merge=lfs -text
|
| 113 |
main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
|
| 114 |
+
main_api_ax650 filter=lfs diff=lfs merge=lfs -text
|
| 115 |
+
main_api_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
output*.wav
|
| 2 |
-
__pycache__/
|
|
|
|
|
|
|
|
|
| 1 |
output*.wav
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.crt
|
| 4 |
+
*.key
|
main_api_ax650
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ed4d3c19fb9c44ebbc565dff37e6049debec2223302c1d334cbf895df5b39ab9
|
| 3 |
+
size 9653912
|
main_api_axcl_aarch64
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:308b67412ada48aeec9ed0bf9f5af60426db4f5ca6cd9e3e2a2e8ea108e3b735
|
| 3 |
+
size 4901808
|
run_api_ax650.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch the CosyVoice2 API server binary for AX650 (./main_api_ax650).
# Run from the repository root: every path below is relative to the current
# working directory.
LLM_DIR=CosyVoice-BlankEN-Ax650-prefill_512/
TOKEN2WAV_DIR=token2wav-axmodels/

# (Re)generate a one-year self-signed certificate for CN=localhost.
# scripts/gradio_demo.py launches its HTTPS UI with ./server.crt and ./server.key.
openssl req -newkey rsa:2048 -new -nodes -x509 -days 365 -keyout server.key -out server.crt -subj "/C=CN/ST=Beijing/L=Beijing/O=YourOrg/CN=localhost"

# Drop audio left over from a previous run; -f keeps a clean checkout (no
# matching files yet) from printing an error.
rm -f output*.wav

# NOTE(review): --filename_tokenizer_model is given an HTTP URL, so a tokenizer
# service is presumably expected on 127.0.0.1:12345 before this starts — confirm.
./main_api_ax650 \
    --template_filename_axmodel "${LLM_DIR}/qwen2_p128_l%d_together.axmodel" \
    --token2wav_axmodel_dir "$TOKEN2WAV_DIR" \
    --n_timesteps 10 \
    --axmodel_num 24 \
    --bos 0 --eos 0 \
    --filename_tokenizer_model "http://127.0.0.1:12345" \
    --filename_post_axmodel "${LLM_DIR}/qwen2_post.axmodel" \
    --filename_decoder_axmodel "${LLM_DIR}/llm_decoder.axmodel" \
    --filename_tokens_embed "${LLM_DIR}/model.embed_tokens.weight.bfloat16.bin" \
    --filename_llm_embed "${LLM_DIR}/llm.llm_embedding.float16.bin" \
    --filename_speech_embed "${LLM_DIR}/llm.speech_embedding.float16.bin" \
    --continue 0 \
    --prompt_files prompt_files


# Runs once the server exits: open up permissions on whatever audio it wrote.
chmod 777 output*.wav
|
run_api_axcl_aarch64.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch the CosyVoice2 API server binary for AXCL on aarch64
# (./main_api_axcl_aarch64). Run from the repository root: every path below is
# relative to the current working directory.
LLM_DIR=CosyVoice-BlankEN-Ax650-prefill_512/
TOKEN2WAV_DIR=token2wav-axmodels/

# (Re)generate a one-year self-signed certificate for CN=localhost.
# scripts/gradio_demo.py launches its HTTPS UI with ./server.crt and ./server.key.
openssl req -newkey rsa:2048 -new -nodes -x509 -days 365 -keyout server.key -out server.crt -subj "/C=CN/ST=Beijing/L=Beijing/O=YourOrg/CN=localhost"

# Drop audio left over from a previous run; -f keeps a clean checkout (no
# matching files yet) from printing an error.
rm -f output*.wav

# NOTE(review): --filename_tokenizer_model is given an HTTP URL, so a tokenizer
# service is presumably expected on 127.0.0.1:12345 before this starts — confirm.
# NOTE(review): --devices "0," looks like a comma-separated device list — confirm.
./main_api_axcl_aarch64 \
    --template_filename_axmodel "${LLM_DIR}/qwen2_p128_l%d_together.axmodel" \
    --token2wav_axmodel_dir "$TOKEN2WAV_DIR" \
    --n_timesteps 10 \
    --axmodel_num 24 \
    --bos 0 --eos 0 \
    --filename_tokenizer_model "http://127.0.0.1:12345" \
    --filename_post_axmodel "${LLM_DIR}/qwen2_post.axmodel" \
    --filename_decoder_axmodel "${LLM_DIR}/llm_decoder.axmodel" \
    --filename_tokens_embed "${LLM_DIR}/model.embed_tokens.weight.bfloat16.bin" \
    --filename_llm_embed "${LLM_DIR}/llm.llm_embedding.float16.bin" \
    --filename_speech_embed "${LLM_DIR}/llm.speech_embedding.float16.bin" \
    --continue 0 \
    --devices "0," \
    --prompt_files prompt_files


# Runs once the server exits: open up permissions on whatever audio it wrote.
chmod 777 output*.wav
|
scripts/gradio_demo.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import shutil
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import numpy as np
|
| 5 |
+
import requests
|
| 6 |
+
import time
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from frontend import CosyVoiceFrontEnd
|
| 11 |
+
import torchaudio
|
| 12 |
+
import logging
|
| 13 |
+
logging.basicConfig(level=logging.WARNING)
|
| 14 |
+
|
| 15 |
+
import subprocess
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
def get_all_local_ips():
    """Return the non-loopback IPv4 addresses reported by ``ip a``.

    Returns:
        list[str]: Addresses in the order they appear in the command output.
        Empty when the ``ip`` command cannot be executed.
    """
    try:
        result = subprocess.run(['ip', 'a'], capture_output=True, text=True)
    except OSError:
        # The `ip` binary is missing or not executable. The addresses are only
        # printed as a startup hint, so report "none found" instead of
        # aborting the whole demo.
        return []

    # Match every IPv4 address ("inet a.b.c.d"); "inet6" lines do not match
    # because the pattern requires a space right after "inet".
    ips = re.findall(r'inet (\d+\.\d+\.\d+\.\d+)', result.stdout)

    # Filter out loopback addresses.
    return [ip for ip in ips if not ip.startswith('127.')]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Endpoints of the TTS API server this UI talks to. 0.0.0.0 is a bind-side
# wildcard, not a portable destination address (connecting to it only works on
# some platforms), so the client addresses the loopback interface explicitly.
_API_BASE = "http://127.0.0.1:12346"
TTS_URL = f"{_API_BASE}/tts"
GET_URL = f"{_API_BASE}/get"
TIMESTEPS_URL = f"{_API_BASE}/timesteps"
PROMPT_FILES_URL = f"{_API_BASE}/prompt_files"

# Command-line options. `args` ends up holding the parsed namespace and is read
# by the callbacks below (e.g. args.sample_rate).
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default="scripts/CosyVoice-BlankEN", help="tokenizer configuration directory")
parser.add_argument('--wetext_dir', type=str, default="pengzhendong/wetext", help="path to wetext")
parser.add_argument('--sample_rate', type=int, default=24000, help="Sampling rate for prompt audio")
args = parser.parse_args()

# Frontend used to turn a reference recording + transcript into prompt
# features. The ONNX model paths are relative to the working directory.
frontend = CosyVoiceFrontEnd(
    args.model_dir,
    args.wetext_dir,
    "frontend-onnx/campplus.onnx",
    "frontend-onnx/speech_tokenizer_v2.onnx",
    f"{args.model_dir}/spk2info.pt",
    "all",
)
|
| 47 |
+
|
| 48 |
+
def update_audio(audio_input_path, audio_text):
    """Build prompt features from a reference recording and hand them to the TTS server.

    The recording is loaded as 16 kHz mono, passed through
    ``frontend.process_prompt`` together with its transcript, and every
    returned entry whose key does not contain ``"_len"`` is written to
    ``./output_temp/<key>.txt`` as a flat comma-delimited text file. The
    directory path is then POSTed to ``PROMPT_FILES_URL``.

    Args:
        audio_input_path: Path of the reference audio file (from the Gradio
            audio component); may be ``None`` when nothing was provided.
        audio_text: Transcript of the reference audio.

    Returns:
        ``None`` on success, or a ``(None, message)`` tuple when the HTTP
        request fails.

    Raises:
        gr.Error: If no audio was provided, its sample rate is below 16 kHz,
            or the extracted prompt has fewer than 75 speech tokens.
    """
    def load_wav(wav, target_sr):
        """Load *wav*, downmix to mono and resample down to *target_sr*."""
        speech, sample_rate = torchaudio.load(wav, backend='soundfile')
        speech = speech.mean(dim=0, keepdim=True)
        if sample_rate != target_sr:
            # Explicit raise instead of `assert`: this validates user input and
            # must not disappear under `python -O`.
            if sample_rate < target_sr:
                raise gr.Error('wav sample rate {} must be greater than {}'.format(sample_rate, target_sr))
            speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
        return speech

    if not audio_input_path:
        # The button can be pressed before any audio is provided; fail with a
        # readable message instead of crashing inside torchaudio.
        raise gr.Error("❌ 请先提供输入音频")

    output_dir = './output_temp'
    # Start from an empty output directory on every update.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    zero_shot_spk_id = ""
    prompt_speech_16k = load_wav(audio_input_path, 16000)
    prompt_text = audio_text
    print("prompt_text",prompt_text)
    model_input = frontend.process_prompt( prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)

    token_count = model_input["flow_prompt_speech_token"].shape[1]
    print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
    if token_count < 75:
        raise gr.Error(f"speech_token length should >= 75, but got {token_count}")

    for k, v in model_input.items():
        if "_len" in k:
            continue
        flat = v.detach().cpu().numpy().reshape(-1)
        if v.dtype in (torch.int32, torch.int64):
            # Integer tensors are written without a fractional part.
            np.savetxt(f"{output_dir}/{k}.txt", flat, fmt="%d", delimiter=",")
        else:
            np.savetxt(f"{output_dir}/{k}.txt", flat, delimiter=",")

    # NOTE(review): a relative path is sent, so the server presumably has to
    # run with the same working directory as this script — confirm.
    try:
        r = requests.post(PROMPT_FILES_URL, json={"prompt_files": output_dir}, timeout=5)
        if r.status_code != 200:
            return None, "❌ TTS 请求失败"
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def update_timesteps(timesteps):
    """Send the chosen number of timesteps to the TTS server.

    Args:
        timesteps: Value forwarded as the ``"timesteps"`` field of the JSON body.

    Returns:
        ``None`` when the server answers with HTTP 200, otherwise a
        ``(None, message)`` tuple describing the failure.
    """
    payload = {"timesteps": timesteps}
    try:
        response = requests.post(TIMESTEPS_URL, json=payload, timeout=5)
    except Exception as err:
        return None, f"❌ TTS 请求异常: {err}"
    if response.status_code != 200:
        return None, "❌ TTS 请求失败"
|
| 93 |
+
|
| 94 |
+
def run_tts(text, progress=gr.Progress()):
    """Submit *text* to the TTS server and wait for the generated audio file.

    Args:
        text: Text to synthesize; sent as the ``"text"`` field of the JSON body.
        progress: Gradio progress tracker. It is declared as a default-valued
            parameter because that is the documented way to attach the tracker
            to an event handler; callers do not pass it.

    Returns:
        tuple: ``(wav_path, status_message)``. ``wav_path`` is ``None`` when
        the request failed, polling gave up, or the reported file does not
        exist on this machine.
    """
    # Step 1: submit the TTS request.
    try:
        r = requests.post(TTS_URL, json={"text": text}, timeout=5)
        if r.status_code != 200:
            return None, "❌ TTS 请求失败"
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"

    # Step 2: poll /get until the server reports it is no longer running.
    wav_file = None
    for i in range(100):  # at most 100 polls (~50 s) so this can never loop forever
        time.sleep(0.5)
        try:
            resp = requests.post(GET_URL, data="", timeout=5).json()
        except Exception as e:
            return None, f"❌ GET 请求异常: {e}"

        # A missing flag is treated as "still running".
        if resp.get("b_tts_runing", True):
            progress(i / 100, desc="正在生成语音...")
        else:
            wav_file = resp.get("wav_file")
            break

    # The server returns a file path, so it must be readable from this process.
    if not wav_file or not os.path.exists(wav_file):
        return None, "❌ 语音文件未生成"

    return wav_file, "✅ 生成完成"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Page layout and event wiring for the demo UI.
with gr.Blocks() as demo:
    gr.Markdown("### 🎙️ AXERA CosyVoice2 Demo")

    # Row 1: reference audio, its transcript, and the button that sends them
    # to update_audio.
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="输入音频", type="filepath")
        with gr.Column():
            audio_text = gr.Textbox(label="音频文本(自己改一下或者照着念)", value="锄禾日当午,汗滴禾下土。")
            btn_update = gr.Button("更新音源")


    # Row 2: text to synthesize, the timesteps slider, and the generate button.
    with gr.Row():
        text_input = gr.Textbox(value="琦琦,麻烦你适配一下这个新的模型吧。", label="输入文本")
        with gr.Column():
            # NOTE(review): the run_api_*.sh scripts start the server with
            # --n_timesteps 10 while this slider initially shows 7, and the
            # server is only notified when the slider changes — confirm intended.
            timesteps = gr.Slider(minimum=4, maximum=30, value=7, step=1, label="Timesteps")
            run_btn = gr.Button("生成语音")

    status = gr.Label(label="状态")
    audio_out = gr.Audio(label="生成结果", type="filepath")

    # run_tts returns (wav_path, status_message), matching the two outputs.
    run_btn.click(fn=run_tts, inputs=[text_input], outputs=[audio_out, status])
    # No outputs are wired for the next two handlers.
    timesteps.change(fn=update_timesteps, inputs=timesteps)

    btn_update.click(fn=update_audio, inputs=[audio_input, audio_text])

# Print one reachable URL per non-loopback IPv4 address of this machine.
ips = get_all_local_ips()
for ip in ips:
    print(f"* Running on local URL: https://{ip}:7860")


# Serve over HTTPS on all interfaces using ./server.crt and ./server.key from
# the working directory (the run_api_*.sh scripts generate a self-signed pair
# with these names), hence ssl_verify=False.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssl_certfile="./server.crt",
    ssl_keyfile="./server.key",
    ssl_verify=False
)
|
scripts/requirements.txt
CHANGED
|
@@ -38,3 +38,4 @@ transformers>=4.40.1
|
|
| 38 |
uvicorn==0.30.0
|
| 39 |
wetext==0.0.4
|
| 40 |
wget==3.2
|
|
|
|
|
|
| 38 |
uvicorn==0.30.0
|
| 39 |
wetext==0.0.4
|
| 40 |
wget==3.2
|
| 41 |
+
gradio==5.47.1
|