lihongjie committed on
Commit
41b3743
·
1 Parent(s): f520907

add gradio gui code

Browse files
.gitattributes CHANGED
@@ -111,3 +111,5 @@ token2wav-axmodels/rand_noise_1_80_300.txt filter=lfs diff=lfs merge=lfs -text
111
  token2wav-axmodels/speech_window_2x8x480.txt filter=lfs diff=lfs merge=lfs -text
112
  token2wav-axmodels/hift_p1_50_first.mnn filter=lfs diff=lfs merge=lfs -text
113
  main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
 
 
 
111
  token2wav-axmodels/speech_window_2x8x480.txt filter=lfs diff=lfs merge=lfs -text
112
  token2wav-axmodels/hift_p1_50_first.mnn filter=lfs diff=lfs merge=lfs -text
113
  main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
114
+ main_api_ax650 filter=lfs diff=lfs merge=lfs -text
115
+ main_api_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  output*.wav
2
- __pycache__/
 
 
 
1
  output*.wav
2
+ __pycache__/
3
+ *.crt
4
+ *.key
main_api_ax650 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed4d3c19fb9c44ebbc565dff37e6049debec2223302c1d334cbf895df5b39ab9
3
+ size 9653912
main_api_axcl_aarch64 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:308b67412ada48aeec9ed0bf9f5af60426db4f5ca6cd9e3e2a2e8ea108e3b735
3
+ size 4901808
run_api_ax650.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch the CosyVoice2 API server binary for AX650 (main_api_ax650).

# Model directories. The trailing slash is kept from the original; the doubled
# "/" it produces in the paths below is harmless.
LLM_DIR=CosyVoice-BlankEN-Ax650-prefill_512/
TOKEN2WAV_DIR=token2wav-axmodels/

# Create a self-signed certificate/key pair valid for 365 days.
# scripts/gradio_demo.py loads ./server.crt and ./server.key to serve HTTPS.
openssl req -newkey rsa:2048 -new -nodes -x509 -days 365 -keyout server.key -out server.crt -subj "/C=CN/ST=Beijing/L=Beijing/O=YourOrg/CN=localhost"

# Remove audio left over from a previous run; -f keeps this silent when there
# is nothing to delete (e.g. on the very first run).
rm -f output*.wav
./main_api_ax650 \
    --template_filename_axmodel "${LLM_DIR}/qwen2_p128_l%d_together.axmodel" \
    --token2wav_axmodel_dir "$TOKEN2WAV_DIR" \
    --n_timesteps 10 \
    --axmodel_num 24 \
    --bos 0 --eos 0 \
    --filename_tokenizer_model "http://127.0.0.1:12345" \
    --filename_post_axmodel "${LLM_DIR}/qwen2_post.axmodel" \
    --filename_decoder_axmodel "${LLM_DIR}/llm_decoder.axmodel" \
    --filename_tokens_embed "${LLM_DIR}/model.embed_tokens.weight.bfloat16.bin" \
    --filename_llm_embed "${LLM_DIR}/llm.llm_embedding.float16.bin" \
    --filename_speech_embed "${LLM_DIR}/llm.speech_embedding.float16.bin" \
    --continue 0 \
    --prompt_files prompt_files


# Open up permissions on the audio files produced during the run.
chmod 777 output*.wav
run_api_axcl_aarch64.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch the CosyVoice2 API server binary for AXCL on aarch64
# (main_api_axcl_aarch64).

# Model directories. The trailing slash is kept from the original; the doubled
# "/" it produces in the paths below is harmless.
LLM_DIR=CosyVoice-BlankEN-Ax650-prefill_512/
TOKEN2WAV_DIR=token2wav-axmodels/

# Create a self-signed certificate/key pair valid for 365 days.
# scripts/gradio_demo.py loads ./server.crt and ./server.key to serve HTTPS.
openssl req -newkey rsa:2048 -new -nodes -x509 -days 365 -keyout server.key -out server.crt -subj "/C=CN/ST=Beijing/L=Beijing/O=YourOrg/CN=localhost"

# Remove audio left over from a previous run; -f keeps this silent when there
# is nothing to delete (e.g. on the very first run).
rm -f output*.wav
./main_api_axcl_aarch64 \
    --template_filename_axmodel "${LLM_DIR}/qwen2_p128_l%d_together.axmodel" \
    --token2wav_axmodel_dir "$TOKEN2WAV_DIR" \
    --n_timesteps 10 \
    --axmodel_num 24 \
    --bos 0 --eos 0 \
    --filename_tokenizer_model "http://127.0.0.1:12345" \
    --filename_post_axmodel "${LLM_DIR}/qwen2_post.axmodel" \
    --filename_decoder_axmodel "${LLM_DIR}/llm_decoder.axmodel" \
    --filename_tokens_embed "${LLM_DIR}/model.embed_tokens.weight.bfloat16.bin" \
    --filename_llm_embed "${LLM_DIR}/llm.llm_embedding.float16.bin" \
    --filename_speech_embed "${LLM_DIR}/llm.speech_embedding.float16.bin" \
    --continue 0 \
    --devices "0," \
    --prompt_files prompt_files


# Open up permissions on the audio files produced during the run.
chmod 777 output*.wav
scripts/gradio_demo.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import shutil
3
+ import gradio as gr
4
+ import numpy as np
5
+ import requests
6
+ import time
7
+ import os
8
+
9
+ import torch
10
+ from frontend import CosyVoiceFrontEnd
11
+ import torchaudio
12
+ import logging
13
+ logging.basicConfig(level=logging.WARNING)
14
+
15
+ import subprocess
16
+ import re
17
+
18
def get_all_local_ips():
    """Return the non-loopback IPv4 addresses reported by ``ip a``.

    Returns:
        list[str]: Dotted-quad addresses in the order they appear in the
        command output. An empty list is returned when the ``ip`` tool
        cannot be executed.
    """
    try:
        result = subprocess.run(['ip', 'a'], capture_output=True, text=True)
    except OSError:
        # `ip` is missing or not executable (e.g. a non-Linux host or a
        # minimal container). Report no addresses instead of crashing.
        return []

    # Match every IPv4 address that follows an "inet" keyword.
    ips = re.findall(r'inet (\d+\.\d+\.\d+\.\d+)', result.stdout)

    # Filter out loopback addresses.
    return [ip for ip in ips if not ip.startswith('127.')]
29
+
30
+
31
# HTTP endpoints of the TTS API server this demo talks to.
TTS_URL = "http://0.0.0.0:12346/tts"
GET_URL = "http://0.0.0.0:12346/get"
TIMESTEPS_URL = "http://0.0.0.0:12346/timesteps"
PROMPT_FILES_URL = "http://0.0.0.0:12346/prompt_files"

# Command-line options; the parsed namespace is published as module-level `args`.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    '--model_dir',
    type=str,
    default="scripts/CosyVoice-BlankEN",
    help="tokenizer configuration directionary",
)
arg_parser.add_argument(
    '--wetext_dir',
    type=str,
    default="pengzhendong/wetext",
    help="path to wetext",
)
arg_parser.add_argument(
    '--sample_rate',
    type=int,
    default=24000,
    help="Sampling rate for prompt audio",
)
args = arg_parser.parse_args()

# Shared front-end instance used by update_audio() to turn a prompt recording
# and its transcript into model inputs.
frontend = CosyVoiceFrontEnd(
    args.model_dir,
    args.wetext_dir,
    "frontend-onnx/campplus.onnx",
    "frontend-onnx/speech_tokenizer_v2.onnx",
    f"{args.model_dir}/spk2info.pt",
    "all",
)
47
+
48
def update_audio(audio_input_path, audio_text):
    """Build prompt features from a reference recording and hand them to the TTS server.

    The recording is loaded, averaged over dimension 0 and resampled to
    16 kHz, then passed with its transcript to ``frontend.process_prompt``.
    Every returned tensor whose key does not contain ``"_len"`` is flattened
    and written with ``np.savetxt`` to ``./output_temp/<key>.txt``; the
    directory path is then posted to ``PROMPT_FILES_URL``.

    Args:
        audio_input_path: Path of the reference audio file.
        audio_text: Transcript of the reference audio.

    Returns:
        ``None`` when the server accepts the request, otherwise a
        ``(None, message)`` tuple describing the failure.

    Raises:
        AssertionError: If the recording's sample rate is below the target
            rate, or if fewer than 75 prompt speech tokens are produced.
    """
    def load_wav(wav, target_sr):
        """Load ``wav``, average over dimension 0 and resample down to ``target_sr``."""
        speech, sample_rate = torchaudio.load(wav, backend='soundfile')
        speech = speech.mean(dim=0, keepdim=True)
        if sample_rate != target_sr:
            # Raised explicitly (rather than via `assert`) so the check
            # still runs when Python is started with -O.
            if not sample_rate > target_sr:
                raise AssertionError('wav sample rate {} must be greater than {}'.format(sample_rate, target_sr))
            speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
        return speech

    output_dir = './output_temp'
    # Start from an empty output directory on every call.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    zero_shot_spk_id = ""
    prompt_speech_16k = load_wav(audio_input_path, 16000)
    prompt_text = audio_text
    print("prompt_text",prompt_text)
    model_input = frontend.process_prompt( prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)
    print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
    # Explicit raise keeps this validation active under -O as well.
    if not model_input["flow_prompt_speech_token"].shape[1] >=75:
        raise AssertionError(f"speech_token length should >= 75, bug get {model_input['flow_prompt_speech_token'].shape[1]}")

    for k, v in model_input.items():
        # Entries whose key contains "_len" are not exported.
        if "_len" in k:
            continue
        flat = v.detach().cpu().numpy().reshape(-1)
        if v.dtype in (torch.int32, torch.int64):
            # Integer tensors are written as plain integers.
            np.savetxt(f"{output_dir}/{k}.txt", flat, fmt="%d", delimiter=",")
        else:
            np.savetxt(f"{output_dir}/{k}.txt", flat, delimiter=",")

    # Tell the server where the freshly written prompt files live.
    try:
        r = requests.post(PROMPT_FILES_URL, json={"prompt_files": output_dir}, timeout=5)
        if r.status_code != 200:
            return None, "❌ TTS 请求失败"
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"
84
+
85
+
86
def update_timesteps(timesteps):
    """Send a new timestep count to the TTS server.

    Args:
        timesteps: Value posted under the ``"timesteps"`` JSON key.

    Returns:
        ``None`` when the server answers with status 200, otherwise a
        ``(None, message)`` tuple describing the failure.
    """
    payload = {"timesteps": timesteps}
    try:
        response = requests.post(TIMESTEPS_URL, json=payload, timeout=5)
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"

    if response.status_code != 200:
        return None, "❌ TTS 请求失败"
    return None
93
+
94
def run_tts(text, progress=gr.Progress()):
    """Submit ``text`` for synthesis and poll the server until the audio is ready.

    Args:
        text: Text to synthesize.
        progress: Gradio progress tracker. It is declared as a default
            argument because that is how Gradio detects it and injects a
            tracker bound to the current event; callers do not pass it.

    Returns:
        A ``(wav_file, status_message)`` tuple. ``wav_file`` is ``None``
        when the request fails or no audio file was produced.
    """
    # Step 1: submit the TTS request.
    try:
        r = requests.post(TTS_URL, json={"text": text}, timeout=5)
        if r.status_code != 200:
            return None, "❌ TTS 请求失败"
    except Exception as e:
        return None, f"❌ TTS 请求异常: {e}"

    # Step 2: poll /get for progress.
    max_polls = 100  # upper bound on attempts so the loop cannot run forever
    wav_file = None
    for i in range(max_polls):
        time.sleep(0.5)
        try:
            resp = requests.post(GET_URL, data="", timeout=5).json()
        except Exception as e:
            return None, f"❌ GET 请求异常: {e}"

        # A response without the flag is treated as "still running".
        if resp.get("b_tts_runing", True):
            progress(i / max_polls, desc="正在生成语音...")
        else:
            wav_file = resp.get("wav_file")
            break

    # Covers both a poll timeout and a reported path that does not exist.
    if not wav_file or not os.path.exists(wav_file):
        return None, "❌ 语音文件未生成"

    return wav_file, "✅ 生成完成"
123
+
124
+
125
+ with gr.Blocks() as demo:
126
+ gr.Markdown("### 🎙️ AXERA CosyVoice2 Demo")
127
+
128
+ with gr.Row():
129
+ with gr.Column():
130
+ audio_input = gr.Audio(label="输入音频", type="filepath")
131
+ with gr.Column():
132
+ audio_text = gr.Textbox(label="音频文本(自己改一下或者照着念)", value="锄禾日当午,汗滴禾下土。")
133
+ btn_update = gr.Button("更新音源")
134
+
135
+
136
+ with gr.Row():
137
+ text_input = gr.Textbox(value="琦琦,麻烦你适配一下这个新的模型吧。", label="输入文本")
138
+ with gr.Column():
139
+ timesteps = gr.Slider(minimum=4, maximum=30, value=7, step=1, label="Timesteps")
140
+ run_btn = gr.Button("生成语音")
141
+
142
+ status = gr.Label(label="状态")
143
+ audio_out = gr.Audio(label="生成结果", type="filepath")
144
+
145
+ run_btn.click(fn=run_tts, inputs=[text_input], outputs=[audio_out, status])
146
+ timesteps.change(fn=update_timesteps, inputs=timesteps)
147
+
148
+ btn_update.click(fn=update_audio, inputs=[audio_input, audio_text])
149
+
150
+ ips = get_all_local_ips()
151
+ for ip in ips:
152
+ print(f"* Running on local URL: https://{ip}:7860")
153
+
154
+
155
+ demo.launch(
156
+ server_name="0.0.0.0",
157
+ server_port=7860,
158
+ ssl_certfile="./server.crt",
159
+ ssl_keyfile="./server.key",
160
+ ssl_verify=False
161
+ )
scripts/requirements.txt CHANGED
@@ -38,3 +38,4 @@ transformers>=4.40.1
38
  uvicorn==0.30.0
39
  wetext==0.0.4
40
  wget==3.2
 
 
38
  uvicorn==0.30.0
39
  wetext==0.0.4
40
  wget==3.2
41
+ gradio==5.47.1