HY-2012 committed on
Commit 2e583ba · verified · 1 Parent(s): 5cb4c08

Update the TTS output format
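
The TTS stage now writes each synthesized sentence as its own WAV under
output_dir/segments/ in addition to the final concatenated answer file, and
audio_numpy_concat preallocates its output buffer instead of growing a Python
list. With the default --output_dir, the resulting layout is:

    output_answer/
        processing_summary.txt
        <input>_answer.wav
        segments/
            <input>_answer_sentence_000.wav
            <input>_answer_sentence_001.wav
            ...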

Files changed (1)
  1. ax_spoken_communication_demo.py +666 -747
ax_spoken_communication_demo.py CHANGED
@@ -1,748 +1,667 @@
1
- import os
2
- import time
3
- import librosa
4
- import torch
5
- import argparse
6
- import soundfile as sf
7
- import cn2an
8
- import requests
9
- import re
10
- import numpy as np
11
- import onnxruntime as ort
12
- import axengine as axe
13
-
14
- # Import SenseVoice-related modules
15
- from model import SinusoidalPositionEncoder
16
- from utils.ax_model_bin import AX_SenseVoiceSmall
17
- from utils.ax_vad_bin import AX_Fsmn_vad
18
- from utils.vad_utils import merge_vad
19
- from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
20
-
21
- # Import MeloTTS-related modules
22
- from libmelotts.python.split_utils import split_sentence
23
- from libmelotts.python.text import cleaned_text_to_sequence
24
- from libmelotts.python.text.cleaner import clean_text
25
- from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP
26
-
27
- # Configuration parameters
28
- # TTS parameters
29
- TTS_MODEL_DIR = "libmelotts/models"
30
- TTS_MODEL_FILES = {
31
- "g": "g-zh_mix_en.bin",
32
- "encoder": "encoder-zh.onnx",
33
- "decoder": "decoder-zh.axmodel"
34
- }
35
-
36
- # Qwen LLM API parameters
37
- QWEN_API_URL = "" # API server address, e.g. http://10.126.29.158:8000
38
-
39
-
40
- # TTS helper functions (ported from melotts.py)
41
- def intersperse(lst, item):
42
- result = [item] * (len(lst) * 2 + 1)
43
- result[1::2] = lst
44
- return result
45
-
46
- # Handle characters that cannot be recognized
47
- def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
48
- """音素处理:确保所有数组长度一致"""
49
- try:
50
- norm_text, phone, tone, word2ph = clean_text(text, language_str)
51
-
52
- # Special phonemes map directly to the empty string
53
- phone_mapping = {
54
- 'ɛ': '', 'æ': '', 'ʌ': '', 'ʊ': '', 'ɔ': '', 'ɪ': '', 'ɝ': '', 'ɚ': '', 'ɑ': '',
55
- 'ʒ': '', 'θ': '', 'ð': '', 'ŋ': '', 'ʃ': '', 'ʧ': '', 'ʤ': '', 'ː': '', 'ˈ': '',
56
- 'ˌ': '', 'ʰ': '', 'ʲ': '', 'ʷ': '', 'ʔ': '', 'ɾ': '', 'ɹ': '', 'ɫ': '', 'ɡ': '',
57
- }
58
-
59
- # Process phone and tone together so their lengths stay consistent
60
- processed_phone = []
61
- processed_tone = []
62
- removed_symbols = set()
63
-
64
- for p, t in zip(phone, tone):
65
- if p in phone_mapping:
66
- # Drop special phonemes together with their corresponding tones
67
- removed_symbols.add(p)
68
- elif p in symbol_to_id:
69
- # Keep normal phonemes together with their corresponding tones
70
- processed_phone.append(p)
71
- processed_tone.append(t)
72
- else:
73
- # Drop any other unknown phonemes as well
74
- removed_symbols.add(p)
75
-
76
- # Log the removed phonemes
77
- if removed_symbols:
78
- print(f"[音素过滤] 删除了 {len(removed_symbols)} 个特殊音素: {sorted(removed_symbols)}")
79
- print(f"[音素过滤] 处理后音素序列长度: {len(processed_phone)}")
80
- print(f"[音素过滤] 处理后音调序列长度: {len(processed_tone)}")
81
-
82
- # If there are no valid phonemes, fall back to default phonemes
83
- if not processed_phone:
84
- print("[警告] 没有有效音素,使用默认中文音素")
85
- processed_phone = ['ni', 'hao']
86
- processed_tone = ['1', '3']
87
- word2ph = [1, 1]
88
-
89
- # Make sure word2ph matches the length of the processed phoneme sequence
90
- if len(processed_phone) != len(phone):
91
- print(f"[警告] 音素序列长度变化: {len(phone)} -> {len(processed_phone)}")
92
- # Simple fix: recompute word2ph
93
- word2ph = [1] * len(processed_phone)
94
-
95
- phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)
96
-
97
- phone = intersperse(phone, 0)
98
- tone = intersperse(tone, 0)
99
- language = intersperse(language, 0)
100
-
101
- phone = np.array(phone, dtype=np.int32)
102
- tone = np.array(tone, dtype=np.int32)
103
- language = np.array(language, dtype=np.int32)
104
- word2ph = np.array(word2ph, dtype=np.int32) * 2
105
- word2ph[0] += 1
106
- return phone, tone, language, norm_text, word2ph
107
-
108
- except Exception as e:
109
- print(f"[错误] 文本处理失败: {e}")
110
- import traceback
111
- traceback.print_exc()
112
- raise e
113
-
114
-
115
- def audio_numpy_concat(segment_data_list, sr, speed=1.):
116
- audio_segments = []
117
- for segment_data in segment_data_list:
118
- audio_segments += segment_data.reshape(-1).tolist()
119
- audio_segments += [0] * int((sr * 0.05) / speed)
120
- audio_segments = np.array(audio_segments).astype(np.float32)
121
- return audio_segments
122
-
123
-
124
- def merge_sub_audio(sub_audio_list, pad_size, audio_len):
125
- # Average pad part
126
- if pad_size > 0:
127
- for i in range(len(sub_audio_list) - 1):
128
- sub_audio_list[i][-pad_size:] += sub_audio_list[i+1][:pad_size]
129
- sub_audio_list[i][-pad_size:] /= 2
130
- if i > 0:
131
- sub_audio_list[i] = sub_audio_list[i][pad_size:]
132
-
133
- sub_audio = np.concatenate(sub_audio_list, axis=-1)
134
- return sub_audio[:audio_len]
135
-
136
-
137
- def calc_word2pronoun(word2ph, pronoun_lens):
138
- indice = [0]
139
- for ph in word2ph[:-1]:
140
- indice.append(indice[-1] + ph)
141
- word2pronoun = []
142
- for i, ph in zip(indice, word2ph):
143
- word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
144
- return word2pronoun
145
-
146
-
147
- def generate_slices(word2pronoun, dec_len):
148
- pn_start, pn_end = 0, 0
149
- zp_start, zp_end = 0, 0
150
- zp_len = 0
151
- pn_slices = []
152
- zp_slices = []
153
- while pn_end < len(word2pronoun):
154
- # If the previous slice covers more than 2 words and adding the current word stays within dec_len, overlap the previous two words
155
- if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
156
- zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
157
- zp_start = zp_end - zp_len
158
- pn_start = pn_end - 2
159
- else:
160
- zp_len = 0
161
- zp_start = zp_end
162
- pn_start = pn_end
163
-
164
- while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
165
- zp_len += word2pronoun[pn_end]
166
- pn_end += 1
167
- zp_end = zp_start + zp_len
168
- pn_slices.append(slice(pn_start, pn_end))
169
- zp_slices.append(slice(zp_start, zp_end))
170
- return pn_slices, zp_slices
171
-
172
-
173
- # Detect Chinese vs. English
174
- def lang_detect_with_regex(text):
175
- """
176
- Language detection
177
- """
178
- # Strip all digits
179
- text_without_digits = re.sub(r'\d+', '', text)
180
-
181
- if not text_without_digits:
182
- return 'unknown'
183
-
184
- # Check for Chinese characters (Chinese takes priority)
185
- if re.search(r'[\u4e00-\u9fff]', text_without_digits):
186
- return 'chinese'
187
- else:
188
- # Check for English letters
189
- if re.search(r'[a-zA-Z]', text_without_digits):
190
- return 'english'
191
- else:
192
- return 'unknown'
193
-
194
- class QwenTranslationAPI:
195
- def __init__(self, api_url=QWEN_API_URL):
196
- self.api_url = api_url
197
- self.session_id = f"speech_translate_{int(time.time())}"
198
-
199
- def reset_context(self):
200
- """重置API上下文"""
201
- try:
202
- reset_url = f"{self.api_url}/api/reset"
203
- response = requests.post(reset_url, json={}, timeout=5)
204
- if response.status_code == 200:
205
- print("[API] 上下文重置成功")
206
- return True
207
- else:
208
- print(f"[API] 重置失败,状态码: {response.status_code}")
209
- except Exception as e:
210
- print(f"[API] 重置上下文失败: {e}")
211
- return False
212
-
213
- def translate(self, text_content, max_retries=3, timeout=120):
214
- """调用千问API进行处理"""
215
- if not text_content or text_content.strip() == "":
216
- return "输入文本为空"
217
-
218
- if lang_detect_with_regex(text_content)=='chinese':
219
- prompt_f = "回答(限制在100个字以内)"
220
- else:
221
- prompt_f = "回答(限制在100个字以内)"
222
-
223
- prompt = f"{prompt_f}:{text_content}"
224
- print(f"[API] 发送请求: {prompt}")
225
-
226
- for attempt in range(max_retries):
227
- try:
228
- # Step 1: send the generation request
229
- generate_url = f"{self.api_url}/api/generate"
230
- payload = {
231
- "prompt": prompt,
232
- "temperature": 0.1, # 降低温度以获得更确定的结果
233
- "repetition_penalty": 1.0,
234
- "top-p": 0.9,
235
- "top-k": 40,
236
- "max_new_tokens": 512
237
- }
238
-
239
- print(f"[API] 开始生成请求 (尝试 {attempt + 1}/{max_retries})")
240
- response = requests.post(generate_url, json=payload, timeout=30)
241
- response.raise_for_status()
242
- print("[API] 生成请求成功")
243
-
244
- # Step 2: poll for the result and merge all chunks
245
- result_url = f"{self.api_url}/api/generate_provider"
246
- start_time = time.time()
247
- full_translation = ""
248
- last_chunk = ""
249
- error_detected = False
250
-
251
- while time.time() - start_time < timeout:
252
- try:
253
- result_response = requests.get(result_url, timeout=10)
254
- result_data = result_response.json()
255
-
256
- # Get the current chunk
257
- current_chunk = result_data.get("response", "")#.strip()
258
-
259
- # Check for "setkvcache failed" errors
260
- if "error:" in current_chunk.lower() or "setkvcache failed" in current_chunk.lower():
261
- print(f"[API] 检测到错误: {current_chunk}")
262
- error_detected = True
263
- self.reset_context()
264
- break
265
-
266
- full_translation += current_chunk
267
-
268
- # Check whether generation is done
269
- if result_data.get("done", False):
270
- # Make sure the full result was received
271
- print(f"[API] 完成: {full_translation}")
272
- return full_translation
273
-
274
- time.sleep(0.05)
275
-
276
- except requests.exceptions.RequestException as e:
277
- print(f"[API] 轮询请求失败: {e}")
278
- if time.time() - start_time > timeout:
279
- break
280
- continue
281
-
282
- # If an error was detected and retries remain, try again
283
- if error_detected and attempt < max_retries - 1:
284
- print(f"[API] 等待1秒后重试...")
285
- time.sleep(1)
286
- continue
287
-
288
- print(f"[API] 轮询超时,尝试第 {attempt + 1} 次重试")
289
-
290
- except requests.exceptions.RequestException as e:
291
- print(f"[API] 请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
292
- if attempt < max_retries - 1:
293
- wait_time = 2 ** attempt # exponential backoff
294
- print(f"[API] 等待 {wait_time} 秒后重试...")
295
- time.sleep(wait_time)
296
- else:
297
- return f"失败: {str(e)}"
298
- except Exception as e:
299
- print(f"[API] 过程出错: {e}")
300
- return f"失败: {str(e)}"
301
-
302
- return "超时,请检查API服务状态"
303
-
304
- class SpeechTranslationPipeline:
305
- def __init__(self,
306
- tts_model_dir, tts_model_files,
307
- asr_model_dir="ax_model", seq_len=132,
308
- tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
309
- qwen_api_url=QWEN_API_URL):
310
- self.tts_model_dir = tts_model_dir
311
- self.tts_model_files = tts_model_files
312
- self.asr_model_dir = asr_model_dir
313
- self.seq_len = seq_len
314
- self.tts_dec_len = tts_dec_len
315
- self.sample_rate = sample_rate
316
- self.tts_speed = tts_speed
317
- self.qwen_api_url = qwen_api_url
318
-
319
- # Initialize the ASR models
320
- self._init_asr_models()
321
-
322
- # Initialize the TTS models
323
- self._init_tts_models()
324
-
325
- # Initialize the API client
326
- self.translator = QwenTranslationAPI(api_url=qwen_api_url)
327
-
328
- # Verify that all required files exist
329
- self._validate_files()
330
-
331
- def _init_asr_models(self):
332
- """初始化语音识别相关模型"""
333
- print("Initializing SenseVoice models...")
334
-
335
- # VAD model
336
- self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
337
-
338
- # Positional encoding
339
- self.embed = SinusoidalPositionEncoder()
340
- self.position_encoding = self.embed.get_position_encoding(
341
- torch.randn(1, self.seq_len, 560)).numpy()
342
-
343
- # ASR model
344
- self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
345
-
346
- # Tokenizer
347
- tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
348
- self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
349
-
350
- print("SenseVoice models initialized successfully.")
351
-
352
- def _init_tts_models(self):
353
- """初始化TTS相关模型"""
354
- print("Initializing MeloTTS models...")
355
- init_start = time.time()
356
-
357
- # Load the encoder and decoder models
358
- enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
359
- dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])
360
-
361
- model_load_start = time.time()
362
- self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
363
- self.sess_dec = axe.InferenceSession(dec_model)
364
- print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")
365
-
366
- # Load the static input g
367
- g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
368
- self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)
369
-
370
- # Set the language and symbol mapping (mixed Chinese-English by default)
371
- self.tts_language = "ZH_MIX_EN"
372
- self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}
373
-
374
- # Warm-up: preload all lazily loaded modules (the main time cost)
375
- print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")
376
- warmup_start = time.time()
377
-
378
- # Mixed Chinese-English warm-up
379
- try:
380
- warmup_start_mix = time.time()
381
- warmup_text_mix = "这是一个test测试。"
382
- _, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
383
- print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
384
- except Exception as e:
385
- print(f" Warning: Mixed warm-up failed: {e}")
386
-
387
- total_init_time = (time.time() - init_start) * 1000
388
- print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")
389
-
390
- def _validate_files(self):
391
- """验证所有必需的文件都存在"""
392
- # Check the TTS model files
393
- for key, filename in self.tts_model_files.items():
394
- filepath = os.path.join(self.tts_model_dir, filename)
395
- if not os.path.exists(filepath):
396
- raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")
397
-
398
- # Check whether the API service is reachable (optional)
399
- try:
400
- response = requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
401
- print("[API检查] 千问API服务连接正常")
402
- except:
403
- print("[API警告] 无法连接到千问API服务,请确保已启动API服务")
404
-
405
- def speech_recognition(self, speech, fs):
406
- """
407
- Step 1: speech recognition (ASR)
408
- """
409
- speech_lengths = len(speech)
410
-
411
- # VAD processing
412
- print("Running VAD...")
413
- vad_start_time = time.time()
414
- res_vad = self.model_vad(speech)[0]
415
- vad_segments = merge_vad(res_vad, 15 * 1000)
416
- vad_time_cost = time.time() - vad_start_time
417
- print(f"VAD processing time: {vad_time_cost:.2f} seconds")
418
- print(f"VAD segments detected: {len(vad_segments)}")
419
-
420
- # ASR processing
421
- print("Running ASR...")
422
- asr_start_time = time.time()
423
- all_results = ""
424
-
425
- # Iterate over each VAD segment and process it
426
- for i, segment in enumerate(vad_segments):
427
- segment_start, segment_end = segment
428
- start_sample = int(segment_start / 1000 * fs)
429
- end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
430
- segment_speech = speech[start_sample:end_sample]
431
-
432
- # Create a temp file for the current segment
433
- segment_filename = f"temp_segment_{i}.wav"
434
- sf.write(segment_filename, segment_speech, fs)
435
-
436
- # Run recognition on the current segment
437
- try:
438
- segment_res = self.model_bin(
439
- segment_filename,
440
- "auto", # 语言自动检测
441
- True, # with ITN (inverse text normalization)
442
- self.position_encoding,
443
- tokenizer=self.tokenizer,
444
- )
445
-
446
- all_results += segment_res
447
-
448
- # Clean up the temp file
449
- if os.path.exists(segment_filename):
450
- os.remove(segment_filename)
451
-
452
- except Exception as e:
453
- if os.path.exists(segment_filename):
454
- os.remove(segment_filename)
455
- print(f"Error processing segment {i}: {e}")
456
- continue
457
-
458
- asr_time_cost = time.time() - asr_start_time
459
- print(f"ASR processing time: {asr_time_cost:.2f} seconds")
460
- print(f"ASR Result: {all_results}")
461
-
462
- return all_results.strip()
463
-
464
- def run_translation(self, text_content):
465
- """
466
- Step 2: call the Qwen LLM API for processing
467
- """
468
- print("Starting translation via API...")
469
- translation_start_time = time.time()
470
-
471
- # Process via the API
472
- translate_content = self.translator.translate(text_content)
473
-
474
- translation_time_cost = time.time() - translation_start_time
475
- print(f"Translation processing time: {translation_time_cost:.2f} seconds")
476
- print(f"Translation Result: {translate_content}")
477
-
478
- return translate_content
479
-
480
- def run_tts(self, translate_content, output_dir, output_wav=None):
481
- """
482
- Step 3: synthesize speech with the TTS model
483
- """
484
- output_path = os.path.join(output_dir, output_wav)
485
-
486
- try:
487
- # Convert digits in Chinese text to Chinese numerals
488
- if lang_detect_with_regex(translate_content) == "chinese":
489
- translate_content = cn2an.transform(translate_content, "an2cn")
490
-
491
- print(f"TTS synthesis for text: {translate_content}")
492
-
493
- # Sentence splitting
494
- sens = split_sentence(translate_content, language_str=self.tts_language)
495
- print(f"Text split into {len(sens)} sentences")
496
-
497
- # Final audio list
498
- audio_list = []
499
-
500
- # Iterate over each sentence
501
- for n, se in enumerate(sens):
502
- # Split English words joined at lowercase-uppercase boundaries
503
- if self.tts_language in ['EN', 'ZH_MIX_EN']:
504
- se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)
505
-
506
- print(f"Processing sentence[{n}]: {se}")
507
-
508
- # Convert the text to phonemes and tones
509
- phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
510
- se, self.tts_language, symbol_to_id=self.symbol_to_id)
511
-
512
- # Run the encoder
513
- encoder_start = time.time()
514
- z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
515
- 'phone': phones, 'g': self.tts_g,
516
- 'tone': tones, 'language': lang_ids,
517
- 'noise_scale': np.array([0], dtype=np.float32),
518
- 'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
519
- 'noise_scale_w': np.array([0], dtype=np.float32),
520
- 'sdp_ratio': np.array([0], dtype=np.float32)})
521
- print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")
522
-
523
- # Compute the pronunciation length of each word
524
- word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
525
- # Generate slices
526
- pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)
527
-
528
- audio_len = audio_len[0]
529
- sub_audio_list = []
530
-
531
- for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
532
- zp_slice = z_p[..., zs]
533
-
534
- # Length of z_p before padding
535
- sub_dec_len = zp_slice.shape[-1]
536
- # Length of the output audio before padding
537
- sub_audio_len = 512 * sub_dec_len
538
-
539
- # Pad to dec_len
540
- if zp_slice.shape[-1] < self.tts_dec_len:
541
- zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)
542
-
543
- decoder_start = time.time()
544
- audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()
545
-
546
- # Handle overlap
547
- audio_start = 0
548
- if len(sub_audio_list) > 0:
549
- if pn_slices[i - 1].stop > ps.start:
550
- # Drop the first word
551
- audio_start = 512 * word2pronoun[ps.start]
552
-
553
- audio_end = sub_audio_len
554
- if i < len(pn_slices) - 1:
555
- if ps.stop > pn_slices[i + 1].start:
556
- # Drop the last word
557
- audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]
558
-
559
- audio = audio[audio_start:audio_end]
560
- print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
561
- sub_audio_list.append(audio)
562
-
563
- # Merge the sub-audio segments
564
- sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
565
- audio_list.append(sub_audio)
566
-
567
- # Concatenate the audio of all sentences
568
- audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)
569
-
570
- # Save the audio file
571
- sf.write(output_path, audio, self.sample_rate)
572
- print(f"TTS audio saved to {output_path}")
573
-
574
- return output_path
575
-
576
- except Exception as e:
577
- print(f"TTS synthesis failed: {e}")
578
- import traceback
579
- traceback.print_exc()
580
- raise e
581
-
582
- def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
583
- """
584
- Full pipeline: speech recognition -> Qwen -> TTS synthesis
585
- """
586
-
587
- # Step 1: speech recognition
588
- print("\n----------------------VAD+ASR----------------------------\n")
589
- start_time = time.time() # record the start time
590
- text_content = self.speech_recognition(speech, fs)
591
- asr_time = time.time() - start_time # compute the elapsed time
592
- print(f"语音识别耗时: {asr_time:.2f}")
593
-
594
- if not text_content or text_content.strip() == "":
595
- raise ValueError("ASR未能识别出有效文本")
596
-
597
- # Step 2: Qwen
598
- print("\n---------------------Qwen---------------------------\n")
599
- start_time = time.time() # record the start time
600
- translate_content = self.run_translation(text_content)
601
- translate_time = time.time() - start_time # compute the elapsed time
602
- print(f"qwen耗时: {translate_time:.2f} 秒")
603
-
604
- # Step 3: TTS synthesis
605
- print("-------------------------TTS-------------------------------\n")
606
- start_time = time.time() # record the start time
607
- output_path = self.run_tts(translate_content, output_dir, output_tts)
608
- tts_time = time.time() - start_time # compute the elapsed time
609
- print(f"TTS合成耗时: {tts_time:.2f} 秒")
610
-
611
- return {
612
- "original_text": text_content,
613
- "translated_text": translate_content,
614
- "audio_path": output_path
615
- }
616
-
617
- def main():
618
- parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
619
- parser.add_argument("--audio_dir", type=str, default="./input_question", help="Input audio directory path")
620
- parser.add_argument("--output_dir", type=str, default="./output_answer", help="Output directory")
621
- parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
622
-
623
- args = parser.parse_args()
624
- print("-------------------START------------------------\n")
625
- os.makedirs(args.output_dir, exist_ok=True)
626
-
627
- # Check that the audio directory exists
628
- if not os.path.exists(args.audio_dir):
629
- print(f"错误: 音频目录不存在: {args.audio_dir}")
630
- return
631
-
632
- # Collect all audio files (.wav/.mp3) in the audio directory
633
- audio_files = []
634
- for file in os.listdir(args.audio_dir):
635
- if file.lower().endswith(('.wav', '.mp3')):
636
- audio_files.append(os.path.join(args.audio_dir, file))
637
-
638
- # If no audio files were found
639
- if not audio_files:
640
- print(f"错误: 在目录 {args.audio_dir} 中没有找到音频文件")
641
- return
642
-
643
- # Sort by filename to fix the processing order
644
- audio_files.sort()
645
- print(f"找到 {len(audio_files)} 个音频文件: {[os.path.basename(f) for f in audio_files]}")
646
-
647
- # Initialize the pipeline (only once)
648
- pipeline = SpeechTranslationPipeline(
649
- tts_model_dir=TTS_MODEL_DIR,
650
- tts_model_files=TTS_MODEL_FILES,
651
- asr_model_dir="ax_model",
652
- seq_len=132,
653
- tts_dec_len=128,
654
- sample_rate=44100,
655
- tts_speed=0.8,
656
- qwen_api_url=args.api_url
657
- )
658
-
659
- # Process each audio file
660
- all_results = []
661
- total_start_time = time.time()
662
-
663
- for i, audio_file in enumerate(audio_files):
664
- print(f"\n{'='*60}")
665
- print(f"处理第 {i+1}/{len(audio_files)} 个音频文件: {os.path.basename(audio_file)}")
666
- print(f"{'='*60}")
667
-
668
- file_start_time = time.time()
669
-
670
- try:
671
- # Load the audio
672
- speech, fs = librosa.load(audio_file, sr=None)
673
- if fs != 16000:
674
- print(f"重采样音频从 {fs}Hz 到 16000Hz")
675
- speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
676
- fs = 16000
677
- audio_duration = librosa.get_duration(y=speech, sr=fs)
678
-
679
- # Build the output filename
680
- base_name = os.path.splitext(os.path.basename(audio_file))[0]
681
- output_tts = f"{base_name}_answer.wav"
682
-
683
- # Run the pipeline
684
- result = pipeline.full_pipeline(speech, fs, args.output_dir, output_tts)
685
-
686
- # Compute the processing time
687
- file_time_cost = time.time() - file_start_time
688
-
689
- out_wav = os.path.join(args.output_dir,output_tts)
690
- speech, fs = librosa.load(out_wav, sr=None)
691
- output_duration = librosa.get_duration(y=speech, sr=fs)
692
- rtf = file_time_cost / output_duration
693
-
694
- # Add file info to the result
695
- result.update({
696
- "audio_file": audio_file,
697
- "processing_time": file_time_cost,
698
- "output_duration": output_duration,
699
- "rtf": rtf
700
- })
701
-
702
- all_results.append(result)
703
-
704
- print(f"\n文件处理完成: {os.path.basename(audio_file)}")
705
- print(f"原始文本: {result['original_text']}")
706
- print(f"回答文本: {result['translated_text']}")
707
- print(f"生成音频: {result['audio_path']}")
708
- print(f"处理时间: {file_time_cost:.2f} 秒")
709
- print(f"音频时长: {output_duration:.2f} 秒")
710
- print(f"RTF: {rtf:.2f}")
711
-
712
- except Exception as e:
713
- print(f"处理文件 {audio_file} 时出错: {e}")
714
- import traceback
715
- traceback.print_exc()
716
- continue
717
-
718
- # Print the overall results
719
- total_time_cost = time.time() - total_start_time
720
- print(f"\n{'='*80}")
721
- print("所有文件处理完成!")
722
- print(f"{'='*80}")
723
- print(f"总共处理了 {len(all_results)} 个文件")
724
- print(f"总处理时间: {total_time_cost:.2f} 秒")
725
-
726
- # Save the summary
727
- summary_file = os.path.join(args.output_dir, "processing_summary.txt")
728
- with open(summary_file, 'w', encoding='utf-8') as f:
729
- f.write("批量处理结果汇总\n")
730
- f.write("=" * 50 + "\n\n")
731
-
732
- for i, result in enumerate(all_results):
733
- f.write(f"文件 {i+1}: {os.path.basename(result['audio_file'])}\n")
734
- f.write(f" 原始文本: {result['original_text']}\n")
735
- f.write(f" 回答结果: {result['translated_text']}\n")
736
- f.write(f" 合成音频: {os.path.basename(result['audio_path'])}\n")
737
- f.write(f" 处理时间: {result['processing_time']:.2f} 秒\n")
738
- f.write(f" 音频时长: {result['output_duration']:.2f} 秒\n")
739
- f.write(f" RTF: {result['rtf']:.2f}\n")
740
- f.write("-" * 50 + "\n")
741
-
742
- f.write(f"\n总计: {len(all_results)} 个文件\n")
743
- f.write(f"总处理时间: {total_time_cost:.2f} 秒\n")
744
-
745
- print(f"详细结果已保存到: {summary_file}")
746
-
747
- if __name__ == "__main__":
748
  main()
 
1
+ import os
2
+ import time
3
+ import librosa
4
+ import torch
5
+ import argparse
6
+ import soundfile as sf
7
+ import cn2an
8
+ import requests
9
+ import re
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ import axengine as axe
13
+
14
+ from model import SinusoidalPositionEncoder
15
+ from utils.ax_model_bin import AX_SenseVoiceSmall
16
+ from utils.ax_vad_bin import AX_Fsmn_vad
17
+ from utils.vad_utils import merge_vad
18
+ from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
19
+
20
+ from libmelotts.python.split_utils import split_sentence
21
+ from libmelotts.python.text import cleaned_text_to_sequence
22
+ from libmelotts.python.text.cleaner import clean_text
23
+ from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP
24
+
25
+ # Configuration parameters
26
+ TTS_MODEL_DIR = "libmelotts/models"
27
+ TTS_MODEL_FILES = {
28
+ "g": "g-zh_mix_en.bin",
29
+ "encoder": "encoder-zh.onnx",
30
+ "decoder": "decoder-zh.axmodel"
31
+ }
32
+
33
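+ # Qwen LLM API server address, e.g. http://10.126.29.158:8000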
+ QWEN_API_URL = ""
34
+
35
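+ # TTS helper functions (ported from melotts.py)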
+ def intersperse(lst, item):
36
+ result = [item] * (len(lst) * 2 + 1)
37
+ result[1::2] = lst
38
+ return result
39
+
40
+ def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
41
+ """音素处理:确保所有数组长度一致"""
42
+ try:
43
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
44
+
45
+ phone_mapping = {
46
+ 'ɛ': '', 'æ': '', 'ʌ': '', 'ʊ': '', 'ɔ': '', 'ɪ': '', 'ɝ': '', 'ɚ': '', 'ɑ': '',
47
+ 'ʒ': '', 'θ': '', 'ð': '', 'ŋ': '', 'ʃ': '', 'ʧ': '', 'ʤ': '', 'ː': '', 'ˈ': '',
48
+ 'ˌ': '', 'ʰ': '', 'ʲ': '', 'ʷ': '', 'ʔ': '', 'ɾ': '', 'ɹ': '', 'ɫ': '', 'ɡ': '',
49
+ }
50
+
51
+ processed_phone = []
52
+ processed_tone = []
53
+ removed_symbols = set()
54
+
55
+ for p, t in zip(phone, tone):
56
+ if p in phone_mapping:
57
+ removed_symbols.add(p)
58
+ elif p in symbol_to_id:
59
+ processed_phone.append(p)
60
+ processed_tone.append(t)
61
+ else:
62
+ removed_symbols.add(p)
63
+
64
+ if removed_symbols:
65
+ print(f"[音素过滤] 删除了 {len(removed_symbols)} 个特殊音素")
66
+
67
+ if not processed_phone:
68
+ print("[警告] 没有有效音素,使用默认中文音素")
69
+ processed_phone = ['ni', 'hao']
70
+ processed_tone = ['1', '3']
71
+ word2ph = [1, 1]
72
+
73
+ if len(processed_phone) != len(phone):
74
+ word2ph = [1] * len(processed_phone)
75
+
76
+ phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)
77
+
78
+ phone = intersperse(phone, 0)
79
+ tone = intersperse(tone, 0)
80
+ language = intersperse(language, 0)
81
+
82
+ phone = np.array(phone, dtype=np.int32)
83
+ tone = np.array(tone, dtype=np.int32)
84
+ language = np.array(language, dtype=np.int32)
85
+ word2ph = np.array(word2ph, dtype=np.int32) * 2
86
+ word2ph[0] += 1
87
+ return phone, tone, language, norm_text, word2ph
88
+
89
+ except Exception as e:
90
+ print(f"[错误] 文本处理失败: {e}")
91
+ import traceback
92
+ traceback.print_exc()
93
+ raise e
94
+
95
+ def audio_numpy_concat(segment_data_list, sr, speed=1.):
96
+ """优化版音频拼接"""
97
+ if not segment_data_list:
98
+ return np.array([], dtype=np.float32)
99
+
100
+ total_len = sum(len(segment) for segment in segment_data_list)
101
+ pause_samples = int((sr * 0.05) / speed)
102
+ total_len += pause_samples * (len(segment_data_list) - 1)
103
+
104
+ audio_segments = np.zeros(total_len, dtype=np.float32)
105
+ current_pos = 0
106
+
107
+ for i, segment_data in enumerate(segment_data_list):
108
+ segment_len = len(segment_data)
109
+ segment_flat = segment_data.reshape(-1)
110
+
111
+ audio_segments[current_pos:current_pos + segment_len] = segment_flat
112
+ current_pos += segment_len
113
+
114
+ if i < len(segment_data_list) - 1:
115
+ current_pos += pause_samples
116
+
117
+ return audio_segments
118
+
119
+ def merge_sub_audio(sub_audio_list, pad_size, audio_len):
120
+ if pad_size > 0:
121
+ for i in range(len(sub_audio_list) - 1):
122
+ sub_audio_list[i][-pad_size:] += sub_audio_list[i+1][:pad_size]
123
+ sub_audio_list[i][-pad_size:] /= 2
124
+ if i > 0:
125
+ sub_audio_list[i] = sub_audio_list[i][pad_size:]
126
+
127
+ sub_audio = np.concatenate(sub_audio_list, axis=-1)
128
+ return sub_audio[:audio_len]
129
+
130
+ def calc_word2pronoun(word2ph, pronoun_lens):
131
+ indice = [0]
132
+ for ph in word2ph[:-1]:
133
+ indice.append(indice[-1] + ph)
134
+ word2pronoun = []
135
+ for i, ph in zip(indice, word2ph):
136
+ word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
137
+ return word2pronoun
138
+
139
+ def generate_slices(word2pronoun, dec_len):
140
+ pn_start, pn_end = 0, 0
141
+ zp_start, zp_end = 0, 0
142
+ zp_len = 0
143
+ pn_slices = []
144
+ zp_slices = []
145
+ while pn_end < len(word2pronoun):
146
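+ # If the previous slice covers more than 2 words and adding the current word stays within dec_len, overlap the previous two words.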
+ if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
147
+ zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
148
+ zp_start = zp_end - zp_len
149
+ pn_start = pn_end - 2
150
+ else:
151
+ zp_len = 0
152
+ zp_start = zp_end
153
+ pn_start = pn_end
154
+
155
+ while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
156
+ zp_len += word2pronoun[pn_end]
157
+ pn_end += 1
158
+ zp_end = zp_start + zp_len
159
+ pn_slices.append(slice(pn_start, pn_end))
160
+ zp_slices.append(slice(zp_start, zp_end))
161
+ return pn_slices, zp_slices
162
+
163
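+ # Language detection: digits are ignored, and Chinese takes priority over English.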
+ def lang_detect_with_regex(text):
164
+ text_without_digits = re.sub(r'\d+', '', text)
165
+
166
+ if not text_without_digits:
167
+ return 'unknown'
168
+
169
+ if re.search(r'[\u4e00-\u9fff]', text_without_digits):
170
+ return 'chinese'
171
+ else:
172
+ if re.search(r'[a-zA-Z]', text_without_digits):
173
+ return 'english'
174
+ else:
175
+ return 'unknown'
176
+
177
+ class QwenTranslationAPI:
178
+ def __init__(self, api_url=QWEN_API_URL):
179
+ self.api_url = api_url
180
+ self.session_id = f"speech_translate_{int(time.time())}"
181
+
182
+ def reset_context(self):
183
+ try:
184
+ reset_url = f"{self.api_url}/api/reset"
185
+ response = requests.post(reset_url, json={}, timeout=5)
186
+ if response.status_code == 200:
187
+ print("[API] 上下文重置成功")
188
+ return True
189
+ else:
190
+ print(f"[API] 重置失败,状态码: {response.status_code}")
191
+ except Exception as e:
192
+ print(f"[API] 重置上下文失败: {e}")
193
+ return False
194
+
195
+ def translate(self, text_content, max_retries=3, timeout=120):
196
+ if not text_content or text_content.strip() == "":
197
+ return "输入文本为空"
198
+
199
+ if lang_detect_with_regex(text_content)=='chinese':
200
+ prompt_f = "回答(限制在100个字以内)"
201
+ else:
202
+ prompt_f = "回答(限制在100个字以内)"
203
+
204
+ prompt = f"{prompt_f}:{text_content}"
205
+ print(f"[API] 发送请求: {prompt}")
206
+
207
+ for attempt in range(max_retries):
208
+ try:
209
+ generate_url = f"{self.api_url}/api/generate"
210
+ payload = {
211
+ "prompt": prompt,
212
+ "temperature": 0.1,
213
+ "repetition_penalty": 1.0,
214
+ "top-p": 0.9,
215
+ "top-k": 40,
216
+ "max_new_tokens": 512
217
+ }
218
+
219
+ print(f"[API] 开始生成请求 (尝试 {attempt + 1}/{max_retries})")
220
+ response = requests.post(generate_url, json=payload, timeout=30)
221
+ response.raise_for_status()
222
+ print("[API] 生成请求成功")
223
+
224
+ result_url = f"{self.api_url}/api/generate_provider"
225
+ start_time = time.time()
226
+ full_translation = ""
227
+ error_detected = False
228
+
229
+ while time.time() - start_time < timeout:
230
+ try:
231
+ result_response = requests.get(result_url, timeout=10)
232
+ result_data = result_response.json()
233
+
234
+ current_chunk = result_data.get("response", "")
235
+
236
+ if "error:" in current_chunk.lower() or "setkvcache failed" in current_chunk.lower():
237
+ print(f"[API] 检测到错误: {current_chunk}")
238
+ error_detected = True
239
+ self.reset_context()
240
+ break
241
+
242
+ full_translation += current_chunk
243
+
244
+ if result_data.get("done", False):
245
+ print(f"[API] 完成: {full_translation}")
246
+ return full_translation
247
+
248
+ time.sleep(0.05)
249
+
250
+ except requests.exceptions.RequestException as e:
251
+ print(f"[API] 轮询请求失败: {e}")
252
+ if time.time() - start_time > timeout:
253
+ break
254
+ continue
255
+
256
+ if error_detected and attempt < max_retries - 1:
257
+ print(f"[API] 等待1秒后重试...")
258
+ time.sleep(1)
259
+ continue
260
+
261
+ print(f"[API] 轮询超时,尝试第 {attempt + 1} 次重试")
262
+
263
+ except requests.exceptions.RequestException as e:
264
+ print(f"[API] 请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
265
+ if attempt < max_retries - 1:
266
+ wait_time = 2 ** attempt
267
+ print(f"[API] 等待 {wait_time} 秒后重试...")
268
+ time.sleep(wait_time)
269
+ else:
270
+ return f"失败: {str(e)}"
271
+ except Exception as e:
272
+ print(f"[API] 过程出错: {e}")
273
+ return f"失败: {str(e)}"
274
+
275
+ return "超时,请检查API服务状态"
276
+
277
+ class SpeechTranslationPipeline:
278
+ def __init__(self,
279
+ tts_model_dir, tts_model_files,
280
+ asr_model_dir="ax_model", seq_len=132,
281
+ tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
282
+ qwen_api_url=QWEN_API_URL):
283
+ self.tts_model_dir = tts_model_dir
284
+ self.tts_model_files = tts_model_files
285
+ self.asr_model_dir = asr_model_dir
286
+ self.seq_len = seq_len
287
+ self.tts_dec_len = tts_dec_len
288
+ self.sample_rate = sample_rate
289
+ self.tts_speed = tts_speed
290
+ self.qwen_api_url = qwen_api_url
291
+
292
+ self._init_asr_models()
293
+ self._init_tts_models()
294
+ self.translator = QwenTranslationAPI(api_url=qwen_api_url)
295
+ self._validate_files()
296
+
297
+ def _init_asr_models(self):
298
+ """初始化语音识别相关模型"""
299
+ print("Initializing SenseVoice models...")
300
+
301
+ self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
302
+
303
+ self.embed = SinusoidalPositionEncoder()
304
+ self.position_encoding = self.embed.get_position_encoding(
305
+ torch.randn(1, self.seq_len, 560)).numpy()
306
+
307
+ self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
308
+
309
+ tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
310
+ self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
311
+
312
+ print("SenseVoice models initialized successfully.")
313
+
314
+ def _init_tts_models(self):
315
+ """初始化TTS相关模型"""
316
+ print("Initializing MeloTTS models...")
317
+ init_start = time.time()
318
+ enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
319
+ dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])
320
+
321
+ self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
322
+ self.sess_dec = axe.InferenceSession(dec_model)
323
+
324
+ g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
325
+ self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)
326
+
327
+ self.tts_language = "ZH_MIX_EN"
328
+ self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}
329
+
330
+ # Preload all lazily loaded modules (the main time cost)
331
+ print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")
332
+ warmup_start = time.time()
333
+
334
+ # Mixed Chinese-English warm-up
335
+ try:
336
+ warmup_start_mix = time.time()
337
+ warmup_text_mix = "这是一个test测试。"
338
+ _, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
339
+ print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
340
+ except Exception as e:
341
+ print(f" Warning: Mixed warm-up failed: {e}")
342
+
343
+ total_init_time = (time.time() - init_start) * 1000
344
+ print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")
345
+
346
+ def _validate_files(self):
347
+ for key, filename in self.tts_model_files.items():
348
+ filepath = os.path.join(self.tts_model_dir, filename)
349
+ if not os.path.exists(filepath):
350
+ raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")
351
+
352
+ try:
353
+ response = requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
354
+ print("[API检查] 千问API服务连接正常")
355
+ except:
356
+ print("[API警告] 无法连接到千问API服务")
357
+
358
+ def speech_recognition(self, speech, fs):
359
+ """第一步:语音识别(ASR)"""
360
+ speech_lengths = len(speech)
361
+
362
+ print("Running VAD...")
363
+ vad_start_time = time.time()
364
+ res_vad = self.model_vad(speech)[0]
365
+ vad_segments = merge_vad(res_vad, 15 * 1000)
366
+ vad_time_cost = time.time() - vad_start_time
367
+ print(f"VAD processing time: {vad_time_cost:.2f} seconds")
368
+ print(f"VAD segments detected: {len(vad_segments)}")
369
+
370
+ print("Running ASR...")
371
+ asr_start_time = time.time()
372
+ all_results = ""
373
+
374
+ for i, segment in enumerate(vad_segments):
375
+ segment_start, segment_end = segment
376
+ start_sample = int(segment_start / 1000 * fs)
377
+ end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
378
+ segment_speech = speech[start_sample:end_sample]
379
+
380
+ segment_filename = f"temp_segment_{i}.wav"
381
+ sf.write(segment_filename, segment_speech, fs)
382
+
383
+ try:
384
+ segment_res = self.model_bin(
385
+ segment_filename,
386
+ "auto",
387
+ True,
388
+ self.position_encoding,
389
+ tokenizer=self.tokenizer,
390
+ )
391
+
392
+ all_results += segment_res
393
+
394
+ if os.path.exists(segment_filename):
395
+ os.remove(segment_filename)
396
+
397
+ except Exception as e:
398
+ if os.path.exists(segment_filename):
399
+ os.remove(segment_filename)
400
+ print(f"Error processing segment {i}: {e}")
401
+ continue
402
+
403
+ asr_time_cost = time.time() - asr_start_time
404
+ print(f"ASR processing time: {asr_time_cost:.2f} seconds")
405
+ print(f"ASR Result: {all_results}")
406
+
407
+ return all_results.strip()
408
+
409
+ def run_translation(self, text_content):
410
+ """第二步:调用Qwen大模型API处理"""
411
+ print("Starting translation via API...")
412
+ translation_start_time = time.time()
413
+
414
+ translate_content = self.translator.translate(text_content)
415
+
416
+ translation_time_cost = time.time() - translation_start_time
417
+ print(f"Translation processing time: {translation_time_cost:.2f} seconds")
418
+ print(f"Translation Result: {translate_content}")
419
+
420
+ return translate_content
421
+
422
+ def run_tts(self, translate_content, output_dir, output_wav=None):
423
+ """第三步:使用TTS模型合成语音"""
424
+ output_path = os.path.join(output_dir, output_wav)
425
+
426
+ try:
427
+ if lang_detect_with_regex(translate_content) == "chinese":
428
+ translate_content = cn2an.transform(translate_content, "an2cn")
429
+
430
+ print(f"TTS synthesis for text: {translate_content}")
431
+
432
+ sens = split_sentence(translate_content, language_str=self.tts_language)
433
+ print(f"Text split into {len(sens)} sentences")
434
+
435
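+ # Each synthesized sentence is also saved as its own WAV under output_dir/segments/.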
+ segments_dir = os.path.join(output_dir, "segments")
436
+ os.makedirs(segments_dir, exist_ok=True)
437
+
438
+ audio_list = []
439
+
440
+ for n, se in enumerate(sens):
441
+ if self.tts_language in ['EN', 'ZH_MIX_EN']:
442
+ se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)
443
+
444
+ print(f"Processing sentence[{n}]: {se}")
445
+
446
+ phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
447
+ se, self.tts_language, symbol_to_id=self.symbol_to_id)
448
+
449
+ encoder_start = time.time()
450
+ z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
451
+ 'phone': phones, 'g': self.tts_g,
452
+ 'tone': tones, 'language': lang_ids,
453
+ 'noise_scale': np.array([0], dtype=np.float32),
454
+ 'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
455
+ 'noise_scale_w': np.array([0], dtype=np.float32),
456
+ 'sdp_ratio': np.array([0], dtype=np.float32)})
457
+ encoder_time = time.time() - encoder_start
458
+ print(f"Encoder run time: {encoder_time*1000:.2f}ms")
459
+
460
+ word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
461
+ pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)
462
+
463
+ audio_len = audio_len[0]
464
+ sub_audio_list = []
465
+
466
+ for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
467
+ zp_slice = z_p[..., zs]
468
+
469
+ sub_dec_len = zp_slice.shape[-1]
470
+ sub_audio_len = 512 * sub_dec_len
471
+
472
+ if zp_slice.shape[-1] < self.tts_dec_len:
473
+ zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)
474
+
475
+ decoder_start = time.time()
476
+ audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()
477
+
478
+ audio_start = 0
479
+ if len(sub_audio_list) > 0:
480
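+ # Handle overlap with the previous slice: drop the first word, which the previous slice already produced.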
+ if pn_slices[i - 1].stop > ps.start:
481
+ audio_start = 512 * word2pronoun[ps.start]
482
+
483
+ audio_end = sub_audio_len
484
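+ # Handle overlap with the next slice: drop the last word, which the next slice will produce.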
+ if i < len(pn_slices) - 1:
485
+ if ps.stop > pn_slices[i + 1].start:
486
+ audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]
487
+
488
+ audio = audio[audio_start:audio_end]
489
+ sub_audio_list.append(audio)
490
+
491
+ merge_start = time.time()
492
+ sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
493
+ merge_time = time.time() - merge_start
494
+ print(f"Sentence[{n}] merge time: {merge_time*1000:.2f}ms")
495
+
496
+ output_wav_name = output_wav.split(".wav")[0]
497
+ segment_filename = os.path.join(segments_dir, f"{output_wav_name}_sentence_{n:03d}.wav")
498
+ sf.write(segment_filename, sub_audio, self.sample_rate)
499
+ print(f"Saved segment audio: {segment_filename}")
500
+
501
+ audio_list.append(sub_audio)
502
+
503
+ concat_start = time.time()
504
+ audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)
505
+ concat_time = time.time() - concat_start
506
+ print(f"Audio concatenation time: {concat_time*1000:.2f}ms")
507
+
508
+ sf.write(output_path, audio, self.sample_rate)
509
+ print(f"TTS audio saved to {output_path}")
510
+
511
+ return output_path
512
+
513
+ except Exception as e:
514
+ print(f"TTS synthesis failed: {e}")
515
+ import traceback
516
+ traceback.print_exc()
517
+ raise e
518
+
519
+ def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
520
+ """完整Pipeline:语音识别 -> qwen -> TTS合成"""
521
+
522
+ print("\n----------------------VAD+ASR----------------------------\n")
523
+ start_time = time.time()
524
+ text_content = self.speech_recognition(speech, fs)
525
+ asr_time = time.time() - start_time
526
+ print(f"语音识别耗时: {asr_time:.2f} 秒")
527
+
528
+ if not text_content or text_content.strip() == "":
529
+ raise ValueError("ASR未能识别出有效文本")
530
+
531
+ print("\n---------------------Qwen---------------------------\n")
532
+ start_time = time.time()
533
+ translate_content = self.run_translation(text_content)
534
+ translate_time = time.time() - start_time
535
+ print(f"qwen耗时: {translate_time:.2f} 秒")
536
+
537
+ print("-------------------------TTS-------------------------------\n")
538
+ start_time = time.time()
539
+ output_path = self.run_tts(translate_content, output_dir, output_tts)
540
+ tts_time = time.time() - start_time
541
+ print(f"TTS合成耗时: {tts_time:.2f} 秒")
542
+
543
+ return {
544
+ "original_text": text_content,
545
+ "translated_text": translate_content,
546
+ "audio_path": output_path
547
+ }
548
+
549
+ def main():
550
+ parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
551
+ parser.add_argument("--audio_dir", type=str, default="./input_question", help="Input audio directory path")
552
+ parser.add_argument("--output_dir", type=str, default="./output_answer", help="Output directory")
553
+ parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
554
+
555
+ args = parser.parse_args()
556
+ print("-------------------START------------------------\n")
557
+ os.makedirs(args.output_dir, exist_ok=True)
558
+
559
+ if not os.path.exists(args.audio_dir):
560
+ print(f"错误: 音频目录不存在: {args.audio_dir}")
561
+ return
562
+
563
+ audio_files = []
564
+ for file in os.listdir(args.audio_dir):
565
+ if file.lower().endswith(('.wav', '.mp3')):
566
+ audio_files.append(os.path.join(args.audio_dir, file))
567
+
568
+ if not audio_files:
569
+ print(f"错误: 在目录 {args.audio_dir} 中没有找到音频文件")
570
+ return
571
+
572
+ audio_files.sort()
573
+ print(f"找到 {len(audio_files)} 个音频文件: {[os.path.basename(f) for f in audio_files]}")
574
+
575
+ pipeline = SpeechTranslationPipeline(
576
+ tts_model_dir=TTS_MODEL_DIR,
577
+ tts_model_files=TTS_MODEL_FILES,
578
+ asr_model_dir="ax_model",
579
+ seq_len=132,
580
+ tts_dec_len=128,
581
+ sample_rate=44100,
582
+ tts_speed=0.8,
583
+ qwen_api_url=args.api_url
584
+ )
585
+
586
+ all_results = []
587
+ total_start_time = time.time()
588
+
589
+ for i, audio_file in enumerate(audio_files):
590
+ print(f"\n{'='*60}")
591
+ print(f"处理第 {i+1}/{len(audio_files)} 个音频文件: {os.path.basename(audio_file)}")
592
+ print(f"{'='*60}")
593
+
594
+ file_start_time = time.time()
595
+
596
+ try:
597
+ speech, fs = librosa.load(audio_file, sr=None)
598
+ if fs != 16000:
599
+ print(f"重采样音频从 {fs}Hz 到 16000Hz")
600
+ speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
601
+ fs = 16000
602
+ audio_duration = librosa.get_duration(y=speech, sr=fs)
603
+
604
+ base_name = os.path.splitext(os.path.basename(audio_file))[0]
605
+ output_tts = f"{base_name}_answer.wav"
606
+
607
+ result = pipeline.full_pipeline(speech, fs, args.output_dir, output_tts)
608
+
609
+ file_time_cost = time.time() - file_start_time
610
+
611
+ out_wav = os.path.join(args.output_dir,output_tts)
612
+ speech, fs = librosa.load(out_wav, sr=None)
613
+ output_duration = librosa.get_duration(y=speech, sr=fs)
614
+ rtf = file_time_cost / output_duration
615
+
616
+ result.update({
617
+ "audio_file": audio_file,
618
+ "processing_time": file_time_cost,
619
+ "output_duration": output_duration,
620
+ "rtf": rtf
621
+ })
622
+
623
+ all_results.append(result)
624
+
625
+ print(f"\n文件处理完成: {os.path.basename(audio_file)}")
626
+ print(f"原始文本: {result['original_text']}")
627
+ print(f"回答文本: {result['translated_text']}")
628
+ print(f"生成音频: {result['audio_path']}")
629
+ print(f"处理时间: {file_time_cost:.2f}")
630
+ print(f"音频时长: {output_duration:.2f} 秒")
631
+ print(f"RTF: {rtf:.2f}")
632
+
633
+ except Exception as e:
634
+ print(f"处理文件 {audio_file} 时出错: {e}")
635
+ import traceback
636
+ traceback.print_exc()
637
+ continue
638
+
639
+ total_time_cost = time.time() - total_start_time
640
+ print(f"\n{'='*80}")
641
+ print("所有文件处理完成!")
642
+ print(f"{'='*80}")
643
+ print(f"总共处理了 {len(all_results)} 个文件")
644
+ print(f"总处理时间: {total_time_cost:.2f} 秒")
645
+
646
+ summary_file = os.path.join(args.output_dir, "processing_summary.txt")
647
+ with open(summary_file, 'w', encoding='utf-8') as f:
648
+ f.write("批量处理结果汇总\n")
649
+ f.write("=" * 50 + "\n\n")
650
+
651
+ for i, result in enumerate(all_results):
652
+ f.write(f"文件 {i+1}: {os.path.basename(result['audio_file'])}\n")
653
+ f.write(f" 原始文本: {result['original_text']}\n")
654
+ f.write(f" 回答结果: {result['translated_text']}\n")
655
+ f.write(f" 合成音频: {os.path.basename(result['audio_path'])}\n")
656
+ f.write(f" 处理时间: {result['processing_time']:.2f} 秒\n")
657
+ f.write(f" 音频时长: {result['output_duration']:.2f} 秒\n")
658
+ f.write(f" RTF: {result['rtf']:.2f}\n")
659
+ f.write("-" * 50 + "\n")
660
+
661
+ f.write(f"\n总计: {len(all_results)} 个文件\n")
662
+ f.write(f"总处理时间: {total_time_cost:.2f} 秒\n")
663
+
664
+ print(f"详细结果已保存到: {summary_file}")
665
+
666
+ if __name__ == "__main__":
667
  main()
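
Assuming the argparse defaults above, a typical invocation of the updated demo would be:

    python ax_spoken_communication_demo.py \
        --audio_dir ./input_question \
        --output_dir ./output_answer \
        --api_url http://<qwen-host>:8000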