xlgeng committed on
Commit 6f7f7cd
1 Parent(s): 2b52af0

开始部署 (Start deployment)

Files changed (2)
  1. app.py +163 -147
  2. wenet/osum_echat/init_llmasr.py +1 -1
app.py CHANGED
@@ -50,12 +50,15 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
  device = torch.device("cuda")
  print("开始加载模型 A...")
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)

  print("\n开始加载模型 B...")
  if CHECKPOINT_PATH_B is not None:
      model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
  else:
      model_b, tokenizer_b = None, None
  loaded_models = {
      NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
      NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
@@ -65,6 +68,7 @@ loaded_models = {
  print("\n所有模型已加载完毕。")

  cosyvoice = CosyVoice(cosyvoice_model_path)

  # 将图片转换为 Base64
  with open("./tts/assert/实验室.png", "rb") as image_file:
@@ -109,144 +113,173 @@ for item in prompt_audio_choices:

  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')


- @spaces.GPU
- def do_s2t(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     feat, feat_lens = get_feat_from_wav_path(input_wav_path)
-     print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt, cache_implementation="static")[0]
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"S2T 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return res_text
-
- @spaces.GPU
- def do_s2t4chat(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     feat, feat_lens = get_feat_from_wav_path(input_wav_path)
-     print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     res_text = model.generate4chat(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return res_text
- @spaces.GPU
- def do_s2t4chat_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     feat, feat_lens = get_feat_from_wav_path(input_wav_path)
-     print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     res_text = model.generate4chat_think(wavs=feat, wavs_len=feat_lens, cache_implementation="static")[0]
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"S2T4Chat 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return res_text
-
- @spaces.GPU
- def do_t2s(model, input_prompt, text_for_tts, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     res_tensor = model.generate_tts(device=device, text=text_for_tts, )[0]
-     res_token_list = res_tensor.tolist()
-     res_text = res_token_list[:-1]
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"T2S 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return res_text

  @spaces.GPU
- def do_t2t(model, question_txt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     print(f'开始t2t推理, question_txt: {question_txt}')
-     res_text = model.generate_text2text(device=device, text=question_txt)[0]
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"T2T 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return res_text

- @spaces.GPU
- def do_s2s(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     feat, feat_lens = get_feat_from_wav_path(input_wav_path)
-     print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
-     if is_npu: torch_npu.npu.synchronize()
-     start_time = time.time()
-     output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'

- @spaces.GPU
- def do_s2s_think(model, input_wav_path, input_prompt, profile=False): # 增加 model 参数
-     model.eval().cuda()
-     feat, feat_lens = get_feat_from_wav_path(input_wav_path)
-     print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
-     if is_npu: torch_npu.npu.synchronize()
      start_time = time.time()
-     output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(wavs=feat, wavs_len=feat_lens,)
-     if is_npu: torch_npu.npu.synchronize()
-     end_time = time.time()
-     print(f"S2S 推理消耗时间: {end_time - start_time:.2f} 秒")
-     return f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
-
- @spaces.GPU
- def get_wav_from_token_list(input_list, prompt_speech):
-     cosyvoice.eval().cuda()
-     time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-     wav_path = f"./tmp/{time_str}.wav"
-     return token_list2wav(input_list, prompt_speech, wav_path, cosyvoice)
-


- def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
-     print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
-     if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
-         print("音频信息未输入,且不是T2S或T2T任务")
-         return "错误:需要音频输入"

-     if input_prompt.endswith("_TTS"):
-         text_for_tts = input_prompt.replace("_TTS", "")
-         prompt = "恳请将如下文本转换为其对应的语音token,力求生成最为流畅、自然的语音。"
-         res_text = do_t2s(model, prompt, text_for_tts)
-     elif input_prompt.endswith("_self_prompt"):
-         prompt = input_prompt.replace("_self_prompt", "")
-         res_text = do_s2t(model, input_wav_path, prompt)
-     elif input_prompt.endswith("_T2T"):
-         question_txt = input_prompt.replace("_T2T", "")
-         res_text = do_t2t(model, question_txt)
-     elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
-                           "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
-                           "s2s_no_think"]:
-         res_text = do_s2s(model, input_wav_path, input_prompt)
-     elif input_prompt == "THINK":
-         res_text = do_s2s_think(model, input_wav_path, input_prompt)
-     elif input_prompt == "s2t_no_think":
-         res_text = do_s2t4chat(model, input_wav_path, input_prompt)
-     elif input_prompt == "s2t_think":
-         res_text = do_s2t4chat_think(model, input_wav_path, input_prompt)
-     else:
-         res_text = do_s2t(model, input_wav_path, input_prompt)
-     res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>",
-                                                                                                  "<adult>")
-
-     print("识别结果为:", res_text)
-     return res_text
-
-
- def do_decode(model, tokenizer, input_wav_path, input_prompt): # 增加 model 和 tokenizer 参数
-     print(f'使用模型进行推理: input_wav_path={input_wav_path}, input_prompt={input_prompt}')
-     output_res = true_decode_fuc(model, tokenizer, input_wav_path, input_prompt)
-     return output_res


  def save_to_jsonl(if_correct, wav, prompt, res):
@@ -351,24 +384,7 @@ with gr.Blocks() as demo:
  else:
      input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")

- # 3. 调用重构后的推理函数,传入选择的模型
- output_res = do_decode(model_to_use, tokenizer_to_use, input_wav_path, input_prompt)
-
- # 4. 处理输出 (逻辑不变)
- wav_path_output = input_wav_path
- if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
-     if isinstance(output_res, list): # TTS case
-         wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
-         output_res = "生成的token: " + str(output_res)
-     elif isinstance(output_res, str) and "|" in output_res: # S2S case
-         try:
-             text_res, token_list_str = output_res.split("|")
-             token_list = json.loads(token_list_str)
-             wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
-             output_res = text_res
-         except (ValueError, json.JSONDecodeError) as e:
-             print(f"处理S2S输出时出错: {e}")
-             output_res = f"错误:无法解析模型输出 - {output_res}"

  return output_res, wav_path_output

 
  device = torch.device("cuda")
  print("开始加载模型 A...")
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
+ model_a.eval().cuda()

  print("\n开始加载模型 B...")
  if CHECKPOINT_PATH_B is not None:
      model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
+     model_b.eval().cuda()
  else:
      model_b, tokenizer_b = None, None
+
  loaded_models = {
      NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
      NAME_B: {"model": model_b, "tokenizer": tokenizer_b},

  print("\n所有模型已加载完毕。")

  cosyvoice = CosyVoice(cosyvoice_model_path)
+ cosyvoice.eval().cuda()

  # 将图片转换为 Base64
  with open("./tts/assert/实验室.png", "rb") as image_file:
 

  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')

+ import time
+ import datetime
+ import torch
+ from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav

  @spaces.GPU
+ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
+     """
+     合并所有推理逻辑的单个函数,处理所有任务类型
+     """
+     print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")

+     # 检查音频输入合法性
+     if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
+         print("音频信息未输入,且不是T2S或T2T任务")
+         return "错误:需要音频输入"

+     # 通用初始化:模型设备设置
      start_time = time.time()
+     res_text = None

+     try:
+         # 1. 处理TTS任务
+         if input_prompt.endswith("_TTS"):
+             text_for_tts = input_prompt.replace("_TTS", "")
+             # T2S推理逻辑
+             res_tensor = model.generate_tts(device=device, text=text_for_tts)[0]
+             res_token_list = res_tensor.tolist()
+             res_text = res_token_list[:-1]
+             print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 2. 处理自定义提示任务
+         elif input_prompt.endswith("_self_prompt"):
+             prompt = input_prompt.replace("_self_prompt", "")
+             # S2T推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = model.generate(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+                 prompt=prompt,
+                 cache_implementation="static"
+             )[0]
+             if is_npu: torch_npu.npu.synchronize()
+             print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 3. 处理T2T任务
+         elif input_prompt.endswith("_T2T"):
+             question_txt = input_prompt.replace("_T2T", "")
+             # T2T推理逻辑
+             print(f'开始t2t推理, question_txt: {question_txt}')
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = model.generate_text2text(
+                 device=device,
+                 text=question_txt
+             )[0]
+             if is_npu: torch_npu.npu.synchronize()
+             print(f"T2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 4. 处理S2S无思考任务
+         elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
+                               "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
+                               "s2s_no_think"]:
+             # S2S推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+             )
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
+             print(f"S2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 5. 处理S2S有思考任务
+         elif input_prompt == "THINK":
+             # S2S带思考推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+             )
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
+             print(f"S2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 6. 处理S2T4Chat无思考任务
+         elif input_prompt == "s2t_no_think":
+             # S2T4Chat推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = model.generate4chat(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+                 cache_implementation="static"
+             )[0]
+             if is_npu: torch_npu.npu.synchronize()
+             print(f"S2T4Chat 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 7. 处理S2T4Chat有思考任务
+         elif input_prompt == "s2t_think":
+             # S2T4Chat带思考推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = model.generate4chat_think(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+                 cache_implementation="static"
+             )[0]
+             if is_npu: torch_npu.npu.synchronize()
+             print(f"S2T4Chat 推理消耗时间: {time.time() - start_time:.2f} 秒")
+
+         # 8. 处理默认S2T任务
+         else:
+             # 默认S2T推理逻辑
+             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
+             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
+             if is_npu: torch_npu.npu.synchronize()
+             res_text = model.generate(
+                 wavs=feat,
+                 wavs_len=feat_lens,
+                 prompt=input_prompt,
+                 cache_implementation="static"
+             )[0]
+             if is_npu: torch_npu.npu.synchronize()
+             print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
+             # 替换特定标签
+             res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>",
+                                                                                                          "<adult>")

+     except Exception as e:
+         print(f"推理过程出错: {str(e)}")
+         return f"错误:{str(e)}"
+
+     output_res =res_text
+     # 4. 处理输出 (逻辑不变)
+     wav_path_output = input_wav_path
+     if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
+         if isinstance(output_res, list): # TTS case
+             cosyvoice.eval()
+             time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+             wav_path = f"./tmp/{time_str}.wav"
+             wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
+             # wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
+             output_res = "生成的token: " + str(output_res)
+         elif isinstance(output_res, str) and "|" in output_res: # S2S case
+             try:
+                 text_res, token_list_str = output_res.split("|")
+                 token_list = json.loads(token_list_str)
+                 cosyvoice.eval()
+                 time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+                 wav_path = f"./tmp/{time_str}.wav"
+                 wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
+                 # wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
+                 output_res = text_res
+             except (ValueError, json.JSONDecodeError) as e:
+                 print(f"处理S2S输出时出错: {e}")
+                 output_res = f"错误:无法解析模型输出 - {output_res}"
+     return output_res, wav_path_output


  def save_to_jsonl(if_correct, wav, prompt, res):
 
  else:
      input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")

+ output_res, wav_path_output = true_decode_fuc(model_to_use, tokenizer_to_use, input_wav_path, input_prompt,task_choice ,prompt_speech_data)

  return output_res, wav_path_output

wenet/osum_echat/init_llmasr.py CHANGED
@@ -121,7 +121,7 @@ def init_llmasr(args, configs, is_inference=False):
  elif fire_module == "link_and_lora":
      if k.startswith("encoder"):
          p.requires_grad = False
- logging.info(f"{k} {p.requires_grad} {p.shape} {p.dtype}")
  logging.info('OSUM-EChat:冻结完毕')
  logging.info(configs)

 
  elif fire_module == "link_and_lora":
      if k.startswith("encoder"):
          p.requires_grad = False
+ # logging.info(f"{k} {p.requires_grad} {p.shape} {p.dtype}")
  logging.info('OSUM-EChat:冻结完毕')
  logging.info(configs)
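
The app.py change above collapses the per-task @spaces.GPU helpers into a single GPU-decorated entry point (true_decode_fuc) and moves eval().cuda() to model-load time. A minimal sketch of that ZeroGPU pattern, using illustrative names rather than code from this commit:

import spaces
import torch

# Load weights once at startup, as the commit now does for model_a / model_b / cosyvoice.
model = torch.nn.Linear(16, 4).eval().cuda()

@spaces.GPU  # on ZeroGPU Spaces, the GPU is attached only while this function runs
def infer(x: torch.Tensor) -> torch.Tensor:
    # Keep every CUDA step (feature extraction, generation, vocoding) inside this
    # one decorated function so a request needs a single GPU allocation.
    with torch.no_grad():
        return model(x.cuda()).cpu()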