Spaces:
Running
on
Zero
Running
on
Zero
开始部署
Browse files- app.py +163 -147
- wenet/osum_echat/init_llmasr.py +1 -1
app.py
CHANGED
|
@@ -50,12 +50,15 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
|
|
| 50 |
device = torch.device("cuda")
|
| 51 |
print("开始加载模型 A...")
|
| 52 |
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
|
|
|
|
| 53 |
|
| 54 |
print("\n开始加载模型 B...")
|
| 55 |
if CHECKPOINT_PATH_B is not None:
|
| 56 |
model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
|
|
|
|
| 57 |
else:
|
| 58 |
model_b, tokenizer_b = None, None
|
|
|
|
| 59 |
loaded_models = {
|
| 60 |
NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
|
| 61 |
NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
|
|
@@ -65,6 +68,7 @@ loaded_models = {
|
|
| 65 |
print("\n所有模型已加载完毕。")
|
| 66 |
|
| 67 |
cosyvoice = CosyVoice(cosyvoice_model_path)
|
|
|
|
| 68 |
|
| 69 |
# 将图片转换为 Base64
|
| 70 |
with open("./tts/assert/实验室.png", "rb") as image_file:
|
|
@@ -109,144 +113,173 @@ for item in prompt_audio_choices:
|
|
| 109 |
|
| 110 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
@spaces.GPU
|
| 114 |
-
def do_s2t(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
| 115 |
-
model.eval().cuda()
|
| 116 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 117 |
-
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 118 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 119 |
-
start_time = time.time()
|
| 120 |
-
res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt, cache_implementation="static")[0]
|
| 121 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 122 |
-
end_time = time.time()
|
| 123 |
-
print(f"S2T 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 124 |
-
return res_text
|
| 125 |
-
|
| 126 |
-
@spaces.GPU
|
| 127 |
-
def do_s2t4chat(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
| 128 |
-
model.eval().cuda()
|
| 129 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 130 |
-
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 131 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 132 |
-
start_time = time.time()
|
| 133 |
-
res_text = model.generate4chat(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
|
| 134 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 135 |
-
end_time = time.time()
|
| 136 |
-
print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 137 |
-
return res_text
|
| 138 |
-
@spaces.GPU
|
| 139 |
-
def do_s2t4chat_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
| 140 |
-
model.eval().cuda()
|
| 141 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 142 |
-
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 143 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 144 |
-
start_time = time.time()
|
| 145 |
-
res_text = model.generate4chat_think(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
|
| 146 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 147 |
-
end_time = time.time()
|
| 148 |
-
print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 149 |
-
return res_text
|
| 150 |
-
|
| 151 |
-
@spaces.GPU
|
| 152 |
-
def do_t2s(model, input_prompt, text_for_tts, profile=False): # 增加 model 参数
|
| 153 |
-
model.eval().cuda()
|
| 154 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 155 |
-
start_time = time.time()
|
| 156 |
-
res_tensor = model.generate_tts(device=device, text=text_for_tts, )[0]
|
| 157 |
-
res_token_list = res_tensor.tolist()
|
| 158 |
-
res_text = res_token_list[:-1]
|
| 159 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 160 |
-
end_time = time.time()
|
| 161 |
-
print(f"T2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 162 |
-
return res_text
|
| 163 |
|
| 164 |
@spaces.GPU
|
| 165 |
-
def
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
print(f
|
| 170 |
-
res_text = model.generate_text2text(device=device, text=question_txt)[0]
|
| 171 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 172 |
-
end_time = time.time()
|
| 173 |
-
print(f"T2T 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 174 |
-
return res_text
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 181 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 182 |
-
start_time = time.time()
|
| 183 |
-
output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
|
| 184 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 185 |
-
end_time = time.time()
|
| 186 |
-
print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 187 |
-
return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
|
| 188 |
|
| 189 |
-
|
| 190 |
-
def do_s2s_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
|
| 191 |
-
model.eval().cuda()
|
| 192 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 193 |
-
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 194 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 195 |
start_time = time.time()
|
| 196 |
-
|
| 197 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 198 |
-
end_time = time.time()
|
| 199 |
-
print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
|
| 200 |
-
return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
|
| 201 |
-
|
| 202 |
-
@spaces.GPU
|
| 203 |
-
def get_wav_from_token_list(input_list, prompt_speech):
|
| 204 |
-
cosyvoice.eval().cuda()
|
| 205 |
-
time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 206 |
-
wav_path = f"./tmp/{time_str}.wav"
|
| 207 |
-
return token_list2wav(input_list, prompt_speech, wav_path, cosyvoice)
|
| 208 |
-
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
if input_prompt.endswith("_TTS"):
|
| 218 |
-
text_for_tts = input_prompt.replace("_TTS", "")
|
| 219 |
-
prompt = "恳请将如下文本转换为其对应的语音token,力求生成最为流畅、自然的语音。"
|
| 220 |
-
res_text = do_t2s(model, prompt, text_for_tts)
|
| 221 |
-
elif input_prompt.endswith("_self_prompt"):
|
| 222 |
-
prompt = input_prompt.replace("_self_prompt", "")
|
| 223 |
-
res_text = do_s2t(model, input_wav_path, prompt)
|
| 224 |
-
elif input_prompt.endswith("_T2T"):
|
| 225 |
-
question_txt = input_prompt.replace("_T2T", "")
|
| 226 |
-
res_text = do_t2t(model, question_txt)
|
| 227 |
-
elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
|
| 228 |
-
"请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
|
| 229 |
-
"s2s_no_think"]:
|
| 230 |
-
res_text = do_s2s(model, input_wav_path, input_prompt)
|
| 231 |
-
elif input_prompt == "THINK":
|
| 232 |
-
res_text = do_s2s_think(model, input_wav_path, input_prompt)
|
| 233 |
-
elif input_prompt == "s2t_no_think":
|
| 234 |
-
res_text = do_s2t4chat(model, input_wav_path, input_prompt)
|
| 235 |
-
elif input_prompt == "s2t_think":
|
| 236 |
-
res_text = do_s2t4chat_think(model, input_wav_path, input_prompt)
|
| 237 |
-
else:
|
| 238 |
-
res_text = do_s2t(model, input_wav_path, input_prompt)
|
| 239 |
-
res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>",
|
| 240 |
-
"<adult>")
|
| 241 |
-
|
| 242 |
-
print("识别结果为:", res_text)
|
| 243 |
-
return res_text
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def do_decode(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
|
| 247 |
-
print(f'使用模型进行推理: input_wav_path={input_wav_path}, input_prompt={input_prompt}')
|
| 248 |
-
output_res = true_decode_fuc(model, tokenizer, input_wav_path, input_prompt)
|
| 249 |
-
return output_res
|
| 250 |
|
| 251 |
|
| 252 |
def save_to_jsonl(if_correct, wav, prompt, res):
|
|
@@ -351,24 +384,7 @@ with gr.Blocks() as demo:
|
|
| 351 |
else:
|
| 352 |
input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")
|
| 353 |
|
| 354 |
-
|
| 355 |
-
output_res = do_decode(model_to_use, tokenizer_to_use, input_wav_path, input_prompt)
|
| 356 |
-
|
| 357 |
-
# 4. 处理输出 (逻辑不变)
|
| 358 |
-
wav_path_output = input_wav_path
|
| 359 |
-
if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
|
| 360 |
-
if isinstance(output_res, list): # TTS case
|
| 361 |
-
wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
|
| 362 |
-
output_res = "生成的token: " + str(output_res)
|
| 363 |
-
elif isinstance(output_res, str) and "|" in output_res: # S2S case
|
| 364 |
-
try:
|
| 365 |
-
text_res, token_list_str = output_res.split("|")
|
| 366 |
-
token_list = json.loads(token_list_str)
|
| 367 |
-
wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
|
| 368 |
-
output_res = text_res
|
| 369 |
-
except (ValueError, json.JSONDecodeError) as e:
|
| 370 |
-
print(f"处理S2S输出时出错: {e}")
|
| 371 |
-
output_res = f"错误:无法解析模型输出 - {output_res}"
|
| 372 |
|
| 373 |
return output_res, wav_path_output
|
| 374 |
|
|
|
|
| 50 |
device = torch.device("cuda")
|
| 51 |
print("开始加载模型 A...")
|
| 52 |
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
|
| 53 |
+
model_a.eval().cuda()
|
| 54 |
|
| 55 |
print("\n开始加载模型 B...")
|
| 56 |
if CHECKPOINT_PATH_B is not None:
|
| 57 |
model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
|
| 58 |
+
model_b.eval().cuda()
|
| 59 |
else:
|
| 60 |
model_b, tokenizer_b = None, None
|
| 61 |
+
|
| 62 |
loaded_models = {
|
| 63 |
NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
|
| 64 |
NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
|
|
|
|
| 68 |
print("\n所有模型已加载完毕。")
|
| 69 |
|
| 70 |
cosyvoice = CosyVoice(cosyvoice_model_path)
|
| 71 |
+
cosyvoice.eval().cuda()
|
| 72 |
|
| 73 |
# 将图片转换为 Base64
|
| 74 |
with open("./tts/assert/实验室.png", "rb") as image_file:
|
|
|
|
| 113 |
|
| 114 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
|
| 115 |
|
| 116 |
+
import time
|
| 117 |
+
import datetime
|
| 118 |
+
import torch
|
| 119 |
+
from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
def _sync_npu():
    """Wait for pending NPU work when running on NPU; otherwise do nothing."""
    if is_npu:
        torch_npu.npu.synchronize()


def _load_feat(input_wav_path):
    """Load model input features for a wav file and log their shape."""
    feat, feat_lens = get_feat_from_wav_path(input_wav_path)
    print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
    return feat, feat_lens


def _tokens_to_wav(token_list, prompt_speech_data):
    """Pass a speech-token list to token_list2wav with a timestamped ./tmp target path."""
    cosyvoice.eval()
    time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    wav_path = f"./tmp/{time_str}.wav"
    return token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)


@spaces.GPU
def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
    """Run one inference request and post-process it into ``(text, wav_path)``.

    The branch is selected from ``input_prompt``:

    * ends with ``_TTS``          -> ``model.generate_tts`` (result is a token list)
    * ends with ``_self_prompt``  -> ``model.generate`` with the prompt minus the suffix
    * ends with ``_T2T``          -> ``model.generate_text2text``
    * one of the S2S prompts      -> ``model.generate_s2s_no_stream_with_repetition_penalty``
    * ``"THINK"``                 -> ``model.generate_s2s_no_stream_think_with_repetition_penalty``
    * ``"s2t_no_think"``          -> ``model.generate4chat``
    * ``"s2t_think"``             -> ``model.generate4chat_think``
    * anything else               -> ``model.generate`` with ``input_prompt`` as the prompt

    Args:
        model: Object exposing the ``generate*`` methods listed above.
        tokenizer: Not used by this function; kept so existing call sites keep working.
        input_wav_path: Path of the input audio, or ``None`` for the text-only tasks.
        input_prompt: Prompt text, possibly carrying one of the task suffixes.
        task_choice: Task label from the UI; decides whether speech tokens are
            turned into a wav file.
        prompt_speech_data: Forwarded unchanged to ``token_list2wav``.

    Returns:
        A 2-tuple ``(output_res, wav_path_output)``. Error paths return
        ``(error_message, input_wav_path)`` as well, because the caller always
        unpacks two values.
    """
    print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")

    # Every task except TTS and T2T needs an audio input.
    if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
        print("音频信息未输入,且不是T2S或T2T任务")
        return "错误:需要音频输入", input_wav_path

    start_time = time.time()
    res_text = None

    try:
        # 1. Text-to-speech-token task.
        if input_prompt.endswith("_TTS"):
            # Slice off only the trailing marker so "_TTS" inside the user's
            # text is left alone.
            text_for_tts = input_prompt[:-len("_TTS")]
            res_tensor = model.generate_tts(device=device, text=text_for_tts)[0]
            res_token_list = res_tensor.tolist()
            res_text = res_token_list[:-1]
            print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 2. Speech-to-text with a user-supplied prompt.
        elif input_prompt.endswith("_self_prompt"):
            prompt = input_prompt[:-len("_self_prompt")]
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=prompt,
                cache_implementation="static",
            )[0]
            _sync_npu()
            print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 3. Text-to-text task.
        elif input_prompt.endswith("_T2T"):
            question_txt = input_prompt[:-len("_T2T")]
            print(f'开始t2t推理, question_txt: {question_txt}')
            _sync_npu()
            res_text = model.generate_text2text(
                device=device,
                text=question_txt,
            )[0]
            _sync_npu()
            print(f"T2T 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 4. Speech-to-speech without thinking.
        elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
                              "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
                              "s2s_no_think"]:
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            _sync_npu()
            # Reply text and speech tokens travel together as "text|[tokens]".
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 5. Speech-to-speech with thinking.
        elif input_prompt == "THINK":
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            _sync_npu()
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 6. Speech-to-text chat without thinking.
        elif input_prompt == "s2t_no_think":
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            res_text = model.generate4chat(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static",
            )[0]
            _sync_npu()
            print(f"S2T4Chat 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 7. Speech-to-text chat with thinking.
        elif input_prompt == "s2t_think":
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            res_text = model.generate4chat_think(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static",
            )[0]
            _sync_npu()
            print(f"S2T4Chat 推理消耗时间: {time.time() - start_time:.2f} 秒")

        # 8. Default speech-to-text task driven by the prompt itself.
        else:
            feat, feat_lens = _load_feat(input_wav_path)
            _sync_npu()
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=input_prompt,
                cache_implementation="static",
            )[0]
            _sync_npu()
            print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
            # Collapse the finer-grained age tags into <adult>.
            res_text = (
                res_text.replace("<youth>", "<adult>")
                .replace("<middle_age>", "<adult>")
                .replace("<middle>", "<adult>")
            )

    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        print(f"推理过程出错: {str(e)}")
        return f"错误:{str(e)}", input_wav_path

    output_res = res_text
    # By default the input audio path is returned as the output audio path.
    wav_path_output = input_wav_path
    if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
        if isinstance(output_res, list):  # TTS case: raw token list
            wav_path_output = _tokens_to_wav(output_res, prompt_speech_data)
            output_res = "生成的token: " + str(output_res)
        elif isinstance(output_res, str) and "|" in output_res:  # S2S case
            try:
                # Split on the LAST "|" so a "|" inside the reply text does
                # not break parsing; the token list is always the final field.
                text_res, token_list_str = output_res.rsplit("|", 1)
                token_list = json.loads(token_list_str)
                wav_path_output = _tokens_to_wav(token_list, prompt_speech_data)
                output_res = text_res
            except ValueError as e:  # json.JSONDecodeError subclasses ValueError
                print(f"处理S2S输出时出错: {e}")
                output_res = f"错误:无法解析模型输出 - {output_res}"
    return output_res, wav_path_output
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
def save_to_jsonl(if_correct, wav, prompt, res):
|
|
|
|
| 384 |
else:
|
| 385 |
input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")
|
| 386 |
|
| 387 |
+
output_res, wav_path_output = true_decode_fuc(model_to_use, tokenizer_to_use, input_wav_path, input_prompt,task_choice ,prompt_speech_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
return output_res, wav_path_output
|
| 390 |
|
wenet/osum_echat/init_llmasr.py
CHANGED
|
@@ -121,7 +121,7 @@ def init_llmasr(args, configs, is_inference=False):
|
|
| 121 |
elif fire_module == "link_and_lora":
|
| 122 |
if k.startswith("encoder"):
|
| 123 |
p.requires_grad = False
|
| 124 |
-
logging.info(f"{k} {p.requires_grad} {p.shape} {p.dtype}")
|
| 125 |
logging.info('OSUM-EChat:冻结完毕')
|
| 126 |
logging.info(configs)
|
| 127 |
|
|
|
|
| 121 |
elif fire_module == "link_and_lora":
|
| 122 |
if k.startswith("encoder"):
|
| 123 |
p.requires_grad = False
|
| 124 |
+
# logging.info(f"{k} {p.requires_grad} {p.shape} {p.dtype}")
|
| 125 |
logging.info('OSUM-EChat:冻结完毕')
|
| 126 |
logging.info(configs)
|
| 127 |
|