yueyulin commited on
Commit
3d8f3be
·
verified ·
1 Parent(s): acb8f2a

Update rwkvtts-respark-webrwkv/tts_cli.py

Browse files
Files changed (1) hide show
  1. rwkvtts-respark-webrwkv/tts_cli.py +358 -178
rwkvtts-respark-webrwkv/tts_cli.py CHANGED
@@ -10,6 +10,7 @@ import sys
10
  import re
11
  import time
12
  import warnings
 
13
  from pathlib import Path
14
  from typing import Dict, Any, Tuple, List
15
 
@@ -17,6 +18,27 @@ import numpy as np
17
  import soundfile as sf
18
  import click
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # 抑制警告
21
  warnings.filterwarnings("ignore", category=UserWarning, module="numpy")
22
  warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime")
@@ -30,8 +52,8 @@ try:
30
  HAS_WEBRWKV = True
31
  except ImportError:
32
  HAS_WEBRWKV = False
33
- print("❌ 错误: 需要安装 'webrwkv_py' 库")
34
- print("请运行: pip install webrwkv_py")
35
  sys.exit(1)
36
 
37
  try:
@@ -39,8 +61,8 @@ try:
39
  HAS_ONNX = True
40
  except ImportError:
41
  HAS_ONNX = False
42
- print("❌ 错误: 需要安装 'onnxruntime' 库")
43
- print("请运行: pip install onnxruntime")
44
  sys.exit(1)
45
 
46
  try:
@@ -48,8 +70,8 @@ try:
48
  HAS_TRANSFORMERS = True
49
  except ImportError:
50
  HAS_TRANSFORMERS = False
51
- print("❌ 错误: 需要安装 'transformers' 库")
52
- print("请运行: pip install transformers")
53
  sys.exit(1)
54
 
55
  try:
@@ -57,8 +79,8 @@ try:
57
  HAS_QUESTIONARY = True
58
  except ImportError:
59
  HAS_QUESTIONARY = False
60
- print(" 错误: 需要安装 'questionary' 库来使用交互式界面")
61
- print("请运行: pip install questionary")
62
  sys.exit(1)
63
 
64
  # 导入属性工具
@@ -73,7 +95,7 @@ try:
73
  pitch_choices = list(PITCH_MAP.keys())
74
  speed_choices = list(SPEED_MAP.keys())
75
  except ImportError:
76
- print("⚠️ 警告: 无法导入 properties_util,使用默认选项")
77
  # 默认选项
78
  age_choices = ['child', 'teenager', 'youth-adult', 'middle-aged', 'elderly']
79
  gender_choices = ['female', 'male'] # 与properties_util.py保持一致
@@ -167,26 +189,26 @@ class TTSGenerator:
167
  self.model_path = model_path
168
 
169
  # 初始化 RefAudioUtilities 实例
170
- print('🎿 开始加载音频编码器模型')
171
  try:
172
  audio_tokenizer_path = os.path.join(model_path, 'BiCodecTokenize.onnx')
173
  wav2vec2_path = os.path.join(model_path, 'wav2vec2-large-xlsr-53.onnx')
174
  from ref_audio_utilities import RefAudioUtilities
175
  self.ref_audio_utilities = RefAudioUtilities(audio_tokenizer_path, wav2vec2_path)
176
- print('✅ 音频编码器模型加载成功')
177
  except Exception as e:
178
- print(f'❌ 音频编码器模型加载失败: {e}')
179
  self.ref_audio_utilities = None
180
 
181
  # 缓存ONNX session
182
- print('🎿 开始加载ONNX模型')
183
  try:
184
  self.ort_session = ort.InferenceSession(decoder_path,
185
  providers=['CUDAExecutionProvider','CPUExecutionProvider'])
186
- print(f"🖥️ONNX Session for generate wavform actual providers: {self.ort_session.get_providers()}")
187
- print('✅ ONNX模型加载成功')
188
  except Exception as e:
189
- print(f'❌ ONNX模型加载失败: {e}')
190
  raise
191
 
192
  # 生成统计信息
@@ -213,9 +235,9 @@ class TTSGenerator:
213
  """重置runtime状态"""
214
  try:
215
  self.runtime.reset()
216
- print("🔄 Runtime状态已重置")
217
  except Exception as e:
218
- print(f"⚠️ Runtime重置失败: {e}")
219
 
220
  def generate_audio(self, params: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
221
  """生成音频"""
@@ -233,15 +255,15 @@ class TTSGenerator:
233
  ref_audio_path = params['ref_audio_path']
234
  prompt_text = params.get('prompt_text', "希望你以后能够做的,比我还好呦!")
235
 
236
- print(f"🎯 开始生成音频 (Zero Shot 模式): {text}")
237
- print(f"📊 参数: 参考音频={ref_audio_path}, 提示文本={prompt_text}")
238
 
239
  # 检测语言
240
  lang = detect_token_lang(text)
241
- print(f"🌍 检测到语言: {lang}")
242
 
243
  # 使用 zero shot 方法生成 tokens
244
- global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed = self._generate_tokens_zeroshot(text, ref_audio_path, prompt_text)
245
  else:
246
  # 传统模式
247
  age = params['age']
@@ -250,46 +272,29 @@ class TTSGenerator:
250
  pitch = params['pitch']
251
  speed = params['speed']
252
 
253
- print(f"🎯 开始生成音频: {text}")
254
- print(f"📊 参数: 年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}")
255
 
256
  # 检测语言
257
  lang = detect_token_lang(text)
258
- print(f"🌍 检测到语言: {lang}")
259
 
260
  # 生成global tokens和semantic tokens
261
- global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed = self._generate_tokens(text, age, gender, emotion, pitch, speed)
 
262
 
263
  # 解码音频
264
- print("🎵 解码音频...")
265
- decode_start = time.time()
266
 
267
- # 准备输入数据 - 按照tts_gui_simple.py的逻辑
268
- print("🔧 准备解码器输入数据...")
269
- global_tokens_array = np.array(global_tokens, dtype=np.int64).reshape(1, 1, -1)
270
- semantic_tokens_array = np.array(semantic_tokens, dtype=np.int64).reshape(1, -1)
271
- print(f'🎯 生成的全局token: {global_tokens}')
272
- print(f'🎯 生成的语义token: {semantic_tokens}')
273
- print(f'📊 解码器输入形状: global_tokens={global_tokens_array.shape}, semantic_tokens={semantic_tokens_array.shape}')
274
 
275
- # 使用ONNX解码器生成音频
276
- print("🎵 开始ONNX解码器推理...")
277
- outputs = self.ort_session.run(None, {
278
- "global_tokens": global_tokens_array,
279
- "semantic_tokens": semantic_tokens_array
280
- })
281
- wav_data = outputs[0].reshape(-1)
282
- decode_time = time.time() - decode_start
283
-
284
- # 计算音频时长和RTF
285
- audio_duration = len(wav_data) / 16000 # 采样率16kHz
286
- decode_speed = len(semantic_tokens) / decode_time if decode_time > 0 else 0
287
  total_time = time.time() - start_time
288
  total_tokens = len(global_tokens) + len(semantic_tokens)
289
  rtf = total_time / audio_duration if audio_duration > 0 else 0
290
 
291
- print(f" 音频解码完成,时长 {audio_duration:.2f}s,耗时 {decode_time:.2f}s,速度 {decode_speed:.1f} tokens/s")
292
- print(f"📊 总耗时: {total_time:.2f}s,RTF: {rtf:.2f}")
293
 
294
  # 更新统计信息
295
  self.generation_stats['total_generations'] += 1
@@ -303,7 +308,6 @@ class TTSGenerator:
303
  'total_tokens': total_tokens,
304
  'audio_duration': audio_duration,
305
  'rtf': rtf,
306
- 'global_speed': global_speed,
307
  'semantic_speed': semantic_speed,
308
  'decode_speed': decode_speed,
309
  'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
@@ -312,7 +316,7 @@ class TTSGenerator:
312
 
313
  return wav_data, self.generation_stats['last_generation']
314
 
315
- def _generate_tokens(self, text: str, age: str, gender: str, emotion: str, pitch: str, speed: str) -> Tuple[List[int], List[int], float, float, float, float]:
316
  """
317
  生成global tokens和semantic tokens
318
 
@@ -323,17 +327,17 @@ class TTSGenerator:
323
  emotion: 情感参数
324
  pitch: 音高参数
325
  speed: 速度参数
326
-
327
  Returns:
328
  Tuple: (global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed)
329
  """
330
  # 编码文本
331
- print("🔤 编码文本...")
332
  tokens = self.tokenizer.encode(text)
333
- print(f"✅ 文本编码完成,共 {len(tokens)} 个token")
334
 
335
  # 生成全局token
336
- print("🌐 生成全局token...")
337
  global_start = time.time()
338
 
339
  # 准备输入tokens
@@ -344,7 +348,7 @@ class TTSGenerator:
344
  # 构建属性tokens - 使用properties_util.py
345
  from properties_util import convert_standard_properties_to_tokens
346
  properties_text = convert_standard_properties_to_tokens(age, gender, emotion, pitch, speed)
347
- print(f'🔤 属性文本: {properties_text}')
348
  properties_tokens = self.tokenizer.encode(properties_text, add_special_tokens=False)
349
  properties_tokens = [i + 8196 + 4096 for i in properties_tokens]
350
 
@@ -352,36 +356,70 @@ class TTSGenerator:
352
  text_tokens = [i + 8196 + 4096 for i in tokens]
353
 
354
  # 组合所有tokens
355
- all_idx = properties_tokens + [TTS_TAG_2] + text_tokens + [TTS_TAG_0]
356
- print(f'🔢 属性token: {properties_tokens}')
357
- print(f'🔢 文本token: {text_tokens}')
358
- print(f'🎯 组合后的tokens: {all_idx}')
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  # Prefill阶段
361
- print("💎 开始Prefill阶段...")
362
- logits = self.runtime.predict(all_idx)
363
- print(f"✅ Prefill完成,logits长度: {len(logits)}")
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  # 生成全局token - 按照tts_gui_simple.py的逻辑
366
- print("🌍 开始生成全局token...")
367
- global_tokens_size = 32
368
- global_tokens = []
369
-
370
- for i in range(global_tokens_size):
371
- # 从logits中采样token
372
- sampled_id = sample_logits(logits[0:4096], temperature=1.0, top_p=0.95, top_k=20)
373
- global_tokens.append(sampled_id)
374
- # 预测下一个token
375
- sampled_id += 8196
376
- logits = self.runtime.predict_next(sampled_id)
377
-
378
- global_time = time.time() - global_start
379
- global_speed = global_tokens_size / global_time if global_time > 0 else 0
380
- print(f"✅ 全局token生成完成,共 {len(global_tokens)} 个token,耗时 {global_time:.2f}s,速度 {global_speed:.1f} tokens/s")
381
- print(f'🎯 生成的全局token: {global_tokens}')
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  # 生成语义token
384
- print("🧠 生成语义token...")
385
  semantic_start = time.time()
386
 
387
  # 按照tts_gui_simple.py的逻辑生成语义token
@@ -391,17 +429,78 @@ class TTSGenerator:
391
  for i in range(2048): # 最大生成2048个token
392
  sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
393
  if sampled_id == 8192: # 遇到结束标记
394
- print(f"🛑 语义token生成结束,遇到结束标记,共生成 {len(semantic_tokens)} 个token")
395
  break
396
  semantic_tokens.append(sampled_id)
397
  x = self.runtime.predict_next(sampled_id)
398
 
399
  semantic_time = time.time() - semantic_start
400
  semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
401
- print(f"✅ 语义token生成完成,共 {len(semantic_tokens)} 个token,耗时 {semantic_time:.2f}s,速度 {semantic_speed:.1f} tokens/s")
402
 
403
  return global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  def _generate_tokens_zeroshot(self, text: str, ref_audio_path: str, prompt_text: str = "希望你以后能够做的,比我还好呦!") -> Tuple[List[int], List[int], float, float, float, float]:
406
  """
407
  使用 zero shot 方式生成global tokens和semantic tokens
@@ -418,26 +517,24 @@ class TTSGenerator:
418
  raise RuntimeError("RefAudioUtilities 未初始化,无法使用 zero shot 模式")
419
 
420
  # 编码文本
421
- print("🔤 编码文本...")
422
  text_tokens = self.tokenizer.encode(prompt_text + text, add_special_tokens=False)
423
  text_tokens = [i + 8196 + 4096 for i in text_tokens]
424
- print(f"✅ 文本编码完成,共 {len(text_tokens)} 个token")
425
 
426
  # 从参考音频获取 global tokens 和 semantic tokens
427
- print("🎵 处理参考音频...")
428
  global_tokens, prompt_semantic_tokens = self.ref_audio_utilities.tokenize(ref_audio_path)
429
- print(f"✅ 参考音频处理完成")
430
 
431
  # 直接使用flatten()展平数组并转换为Python一维数组
432
  global_tokens = [int(i) + 8196 for i in global_tokens.flatten()]
433
  prompt_semantic_tokens = [int(i) for i in prompt_semantic_tokens.flatten()]
434
 
435
- print(f'🎯 参考音频 global_tokens: {global_tokens}')
436
- print(f'🎯 参考音频 semantic_tokens: {prompt_semantic_tokens}')
 
437
 
438
- # 生成全局token
439
- print("🌐 生成全局token...")
440
- global_start = time.time()
441
 
442
  # 准备输入tokens
443
  TTS_TAG_0 = 8193
@@ -446,19 +543,27 @@ class TTSGenerator:
446
 
447
  # 组合所有tokens
448
  all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1] + prompt_semantic_tokens
449
- print(f'🎯 组合后的tokens: {all_idx}')
450
 
451
  # Prefill阶段
452
- print("💎 开始Prefill阶段...")
453
- logits = self.runtime.predict(all_idx)
454
- print(f"✅ Prefill完成,logits长度: {len(logits)}")
455
-
456
- global_time = time.time() - global_start
457
- global_speed = len(global_tokens) / global_time if global_time > 0 else 0
458
- print(f"✅ 全局token处理完成,共 {len(global_tokens)} 个token,耗时 {global_time:.2f}s,速度 {global_speed:.1f} tokens/s")
 
 
 
 
 
 
 
 
459
 
460
  # 生成语义token
461
- print("🧠 生成语义token...")
462
  semantic_start = time.time()
463
 
464
  # 从当前logits开始生成语义token
@@ -468,52 +573,110 @@ class TTSGenerator:
468
  for i in range(2048): # 最大生成2048个token
469
  sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
470
  if sampled_id == 8192: # 遇到结束标记
471
- print(f"🛑 语义token生成结束,遇到结束标记,共生成 {len(semantic_tokens)} 个token")
472
  break
473
  semantic_tokens.append(sampled_id)
474
  x = self.runtime.predict_next(sampled_id)
475
 
476
  semantic_time = time.time() - semantic_start
477
  semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
478
- print(f"✅ 语义token生成完成,共 {len(semantic_tokens)} 个token,耗时 {semantic_time:.2f}s,速度 {semantic_speed:.1f} tokens/s")
479
 
480
  global_tokens = [i - 8196 for i in global_tokens]
481
- return global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
  def display_stats(stats: Dict[str, Any]):
484
  """显示生成统计信息"""
485
- print("\n" + "="*60)
486
- print("📊 生成统计信息")
487
- print("="*60)
488
 
489
  if stats['text']:
490
- print(f"🎯 生成参数: {stats['params']}")
491
- print(f"📝 文本: {stats['text']}")
492
- print(f"⏱️ 总耗时: {stats['total_time']:.2f}s")
493
- print(f"🎵 音频时长: {stats['audio_duration']:.2f}s")
494
- print(f"📈 RTF: {stats['rtf']:.2f}")
495
- print(f"🔢 总token数: {stats['total_tokens']}")
496
- print(f"🌐 全局token速度: {stats['global_speed']:.1f} tokens/s")
497
- print(f"🧠 语义token速度: {stats['semantic_speed']:.1f} tokens/s")
498
- print(f"🎵 解码速度: {stats['decode_speed']:.1f} tokens/s")
499
- print(f"🕐 时间: {stats['timestamp']}")
500
  if stats['output_path']:
501
- print(f"💾 保存路径: {stats['output_path']}")
502
  else:
503
- print("暂无生成记录")
504
 
505
- print("="*60)
506
 
507
  def interactive_parameter_selection(generator: TTSGenerator):
508
  """交互式参数选择界面"""
509
- print("\n🎮 进入交互式配置界面")
510
- print("💡 使用方向键选择,回车确认,Ctrl+C退出")
511
 
512
  while True:
513
  try:
514
- print("\n" + "="*60)
515
- print("🎵 RWKV TTS 参数配置")
516
- print("="*60)
517
 
518
  # 选择生成模式
519
  generation_mode = questionary.select(
@@ -596,16 +759,18 @@ def interactive_parameter_selection(generator: TTSGenerator):
596
  output_path = get_unique_filename(output_dir, text)
597
 
598
  # 保存音频
599
- sf.write(output_path, wav_data, 16000)
600
- stats['output_path'] = output_path
 
 
601
 
602
- print(f"✅ 音频生成成功,保存至: {output_path}")
603
  stats['生成参数'] = f'参考音频={ref_audio_path}, 提示文本={prompt_text}'
604
  # 显示统计信息
605
  display_stats(stats)
606
 
607
  except Exception as e:
608
- print(f"❌ 生成失败: {e}")
609
  import traceback
610
  traceback.print_exc()
611
  else:
@@ -659,6 +824,20 @@ def interactive_parameter_selection(generator: TTSGenerator):
659
 
660
  if speed is None:
661
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
 
663
 
664
  # 确认生成
@@ -666,7 +845,8 @@ def interactive_parameter_selection(generator: TTSGenerator):
666
  f"🚀 确认生成音频?\n"
667
  f"文本: {text}\n"
668
  f"参数: 年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}\n"
669
- f"输出目录: {output_dir}",
 
670
  default=True
671
  ).ask()
672
 
@@ -680,7 +860,8 @@ def interactive_parameter_selection(generator: TTSGenerator):
680
  'emotion': emotion,
681
  'pitch': pitch,
682
  'speed': speed,
683
- 'output_dir': output_dir
 
684
  }
685
 
686
  # 生成音频
@@ -691,16 +872,18 @@ def interactive_parameter_selection(generator: TTSGenerator):
691
  output_path = get_unique_filename(output_dir, text)
692
 
693
  # 保存音频
694
- sf.write(output_path, wav_data, 16000)
695
- stats['output_path'] = output_path
 
 
696
 
697
- print(f"✅ 音频生成成功,保存至: {output_path}")
698
  stats['生成参数'] = f'年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}'
699
  # 显示统计信息
700
  display_stats(stats)
701
 
702
  except Exception as e:
703
- print(f"❌ 生成失败: {e}")
704
  import traceback
705
  traceback.print_exc()
706
 
@@ -714,61 +897,57 @@ def interactive_parameter_selection(generator: TTSGenerator):
714
  break
715
 
716
  except KeyboardInterrupt:
717
- print("\n👋 用户中断,退出程序")
718
  break
719
  except Exception as e:
720
- print(f"❌ 发生错误: {e}")
721
  import traceback
722
  traceback.print_exc()
723
  break
724
 
725
- print("👋 感谢使用 RWKV TTS!")
726
 
727
  @click.command()
728
  @click.option('--model_path', required=True, help='RWKV模型路径')
729
  def main(model_path):
730
  """RWKV TTS 主程序"""
731
- print("🚀 欢迎使用 RWKV TTS 交互式音频生成工具!")
732
 
733
  # 检查模型文件
734
  if not os.path.exists(model_path):
735
- print(f"❌ 错误: 模型路径不存在: {model_path}")
736
  return
737
 
738
  # 自动构建解码器路径
739
  decoder_path = os.path.join(model_path, "BiCodecDetokenize.onnx")
740
- print(f"🔍 自动设置解码器路径: {decoder_path}")
741
 
742
  # 检查模型目录中的文件
743
- print(f"🔍 检查模型目录: {model_path}")
744
  try:
745
  model_files = os.listdir(model_path)
746
- print(f"📁 模型目录中的文件:")
747
  for file in model_files:
748
  file_path = os.path.join(model_path, file)
749
  if os.path.isfile(file_path):
750
  size = os.path.getsize(file_path)
751
- print(f" 📄 {file} ({size:,} bytes)")
752
  else:
753
- print(f" 📁 {file}/")
754
  except Exception as e:
755
- print(f"⚠️ 无法列出模型目录内容: {e}")
756
 
757
  if not os.path.exists(decoder_path):
758
- print(f"❌ 错误: 解码器路径不存在: {decoder_path}")
759
  return
760
 
761
  # 选择设备
762
- print("\n💎 选择设备 💎")
763
  try:
764
  devices = webrwkv_py.get_available_adapters_py()
765
- except AttributeError:
766
- # 如果新API不存在,尝试旧API
767
- try:
768
- devices = webrwkv_py.get_available_devices()
769
- except AttributeError:
770
- print("❌ 无法获取可用设备列表")
771
- return
772
 
773
  for i, device in enumerate(devices):
774
  print(f"{i}: {device}")
@@ -777,16 +956,16 @@ def main(model_path):
777
  try:
778
  device_idx = int(device_choice)
779
  if device_idx < 0 or device_idx >= len(devices):
780
- print("❌ 无效的设备选择")
781
  return
782
  device = devices[device_idx]
783
- print(f"✅ 选择设备: {device}")
784
  except ValueError:
785
- print("❌ 无效的设备选择")
786
  return
787
 
788
  # 加载模型
789
- print("\n💎 加载模型 💎")
790
  try:
791
  # 尝试多种可能的模型文件名
792
  possible_model_files = [
@@ -798,55 +977,56 @@ def main(model_path):
798
  test_path = os.path.join(model_path, model_file)
799
  if os.path.exists(test_path):
800
  webrwkv_model_path = test_path
801
- print(f"✅ 找到模型文件: {model_file}")
802
  break
803
 
804
  if webrwkv_model_path is None:
805
- print(f"❌ 未找到模型文件")
806
- print(f"💡 请检查模型目录 {model_path} 中是否包含以下文件之一:")
807
  for model_file in possible_model_files:
808
- print(f" - {model_file}")
809
  return
810
 
811
- print(f"🔍 尝试加载模型文件: {webrwkv_model_path}")
812
 
813
  # 尝试新的API
814
  model = webrwkv_py.Model(webrwkv_model_path, 'fp32', device_idx)
815
- print(f"✅ 模型加载成功: {webrwkv_model_path}")
816
  except Exception as e:
817
- print(f"❌ 模型加载失败: {e}")
818
- print(f"💡 请检查:")
819
- print(f" 1. 模型文件路径是否正确: {webrwkv_model_path}")
820
- print(f" 2. 模型文件是否完整")
821
- print(f" 3. 设备索引是否正确: {device_idx}")
822
- print(f" 4. 模型文件格式是否支持")
823
  return
824
 
825
  # 创建runtime
826
- print("\n💎 创建 runtime 💎")
827
  try:
828
  runtime = model.create_thread_runtime()
829
- print("✅ runtime 创建成功")
830
  except Exception as e:
831
- print(f"❌ runtime 创建失败: {e}")
832
  return
833
 
834
  # 加载tokenizer
835
- print("\n💎 加载 tokenizer 💎")
836
  try:
837
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
838
- print(f"✅ tokenizer 加载成功: {model_path}")
839
  except Exception as e:
840
- print(f"❌ tokenizer 加载失败: {e}")
841
- print(f"💡 请检查模型目录 {model_path} 中是否包含正确的tokenizer文件")
842
  return
843
 
844
  # 创建TTS生成器
845
  generator = TTSGenerator(runtime, tokenizer, decoder_path, device, model_path)
846
 
847
  # 启动交互式界面
848
- print("\n🎯 启动交互式配置界面...")
849
  interactive_parameter_selection(generator)
850
 
851
  if __name__ == "__main__":
852
  main()
 
 
10
  import re
11
  import time
12
  import warnings
13
+ import logging
14
  from pathlib import Path
15
  from typing import Dict, Any, Tuple, List
16
 
 
18
  import soundfile as sf
19
  import click
20
 
21
+ generated_global_tokens = {}
22
+
23
+ # 配置日志
24
+ def setup_logging():
25
+ """设置日志配置"""
26
+ # 从环境变量获取日志级别,默认为WARNING
27
+ log_level_str = os.environ.get('LOG_LEVEL', 'INFO').upper()
28
+ log_level = getattr(logging, log_level_str, logging.WARNING)
29
+
30
+ # 配置日志格式
31
+ logging.basicConfig(
32
+ level=log_level,
33
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
34
+ datefmt='%Y-%m-%d %H:%M:%S'
35
+ )
36
+
37
+ return logging.getLogger(__name__)
38
+
39
+ # 创建logger实例
40
+ logger = setup_logging()
41
+
42
  # 抑制警告
43
  warnings.filterwarnings("ignore", category=UserWarning, module="numpy")
44
  warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime")
 
52
  HAS_WEBRWKV = True
53
  except ImportError:
54
  HAS_WEBRWKV = False
55
+ logger.error("❌ 错误: 需要安装 'webrwkv_py' 库")
56
+ logger.error("请运行: pip install webrwkv_py")
57
  sys.exit(1)
58
 
59
  try:
 
61
  HAS_ONNX = True
62
  except ImportError:
63
  HAS_ONNX = False
64
+ logger.error("❌ 错误: 需要安装 'onnxruntime' 库")
65
+ logger.error("请运行: pip install onnxruntime")
66
  sys.exit(1)
67
 
68
  try:
 
70
  HAS_TRANSFORMERS = True
71
  except ImportError:
72
  HAS_TRANSFORMERS = False
73
+ logger.error("❌ 错误: 需要安装 'transformers' 库")
74
+ logger.error("请运行: pip install transformers")
75
  sys.exit(1)
76
 
77
  try:
 
79
  HAS_QUESTIONARY = True
80
  except ImportError:
81
  HAS_QUESTIONARY = False
82
+ logger.warning("⚠️ 警告: 无法导入 questionary 库来使用交互式界面")
83
+ logger.warning("请运行: pip install questionary")
84
  sys.exit(1)
85
 
86
  # 导入属性工具
 
95
  pitch_choices = list(PITCH_MAP.keys())
96
  speed_choices = list(SPEED_MAP.keys())
97
  except ImportError:
98
+ logger.warning("⚠️ 警告: 无法导入 properties_util,使用默认选项")
99
  # 默认选项
100
  age_choices = ['child', 'teenager', 'youth-adult', 'middle-aged', 'elderly']
101
  gender_choices = ['female', 'male'] # 与properties_util.py保持一致
 
189
  self.model_path = model_path
190
 
191
  # 初始化 RefAudioUtilities 实例
192
+ logger.info('🎿 开始加载音频编码器模型')
193
  try:
194
  audio_tokenizer_path = os.path.join(model_path, 'BiCodecTokenize.onnx')
195
  wav2vec2_path = os.path.join(model_path, 'wav2vec2-large-xlsr-53.onnx')
196
  from ref_audio_utilities import RefAudioUtilities
197
  self.ref_audio_utilities = RefAudioUtilities(audio_tokenizer_path, wav2vec2_path)
198
+ logger.info('✅ 音频编码器模型加载成功')
199
  except Exception as e:
200
+ logger.error(f'❌ 音频编码器模型加载失败: {e}')
201
  self.ref_audio_utilities = None
202
 
203
  # 缓存ONNX session
204
+ logger.info('🎿 开始加载ONNX模型')
205
  try:
206
  self.ort_session = ort.InferenceSession(decoder_path,
207
  providers=['CUDAExecutionProvider','CPUExecutionProvider'])
208
+ logger.info(f"🖥️ONNX Session for generate wavform actual providers: {self.ort_session.get_providers()}")
209
+ logger.info('✅ ONNX模型加载成功')
210
  except Exception as e:
211
+ logger.error(f'❌ ONNX模型加载失败: {e}')
212
  raise
213
 
214
  # 生成统计信息
 
235
  """重置runtime状态"""
236
  try:
237
  self.runtime.reset()
238
+ logger.info("🔄 Runtime状态已重置")
239
  except Exception as e:
240
+ logger.warning(f"⚠️ Runtime重置失败: {e}")
241
 
242
  def generate_audio(self, params: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
243
  """生成音频"""
 
255
  ref_audio_path = params['ref_audio_path']
256
  prompt_text = params.get('prompt_text', "希望你以后能够做的,比我还好呦!")
257
 
258
+ logger.info(f"🎯 开始生成音频 (Zero Shot 模式): {text}")
259
+ logger.info(f"📊 参数: 参考音频={ref_audio_path}, 提示文本={prompt_text}")
260
 
261
  # 检测语言
262
  lang = detect_token_lang(text)
263
+ logger.info(f"🌍 检测到语言: {lang}")
264
 
265
  # 使用 zero shot 方法生成 tokens
266
+ global_tokens, semantic_tokens, semantic_time, semantic_speed = self._generate_tokens_zeroshot(text, ref_audio_path, prompt_text)
267
  else:
268
  # 传统模式
269
  age = params['age']
 
272
  pitch = params['pitch']
273
  speed = params['speed']
274
 
275
+ logger.info(f"🎯 开始生成音频: {text}")
276
+ logger.info(f"📊 参数: 年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}")
277
 
278
  # 检测语言
279
  lang = detect_token_lang(text)
280
+ logger.info(f"🌍 检测到语言: {lang}")
281
 
282
  # 生成global tokens和semantic tokens
283
+ generated_key = params['generated_key']
284
+ global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed = self._generate_tokens(text, age, gender, emotion, pitch, speed, generated_key)
285
 
286
  # 解码音频
287
+ logger.info("🎵 解码音频...")
 
288
 
289
+ # 使用抽象化的音频解码函数
290
+ wav_data, audio_duration, decode_time, decode_speed = self._decode_audio(global_tokens, semantic_tokens)
 
 
 
 
 
291
 
292
+ # 计算总耗时和RTF
 
 
 
 
 
 
 
 
 
 
 
293
  total_time = time.time() - start_time
294
  total_tokens = len(global_tokens) + len(semantic_tokens)
295
  rtf = total_time / audio_duration if audio_duration > 0 else 0
296
 
297
+ logger.info(f"📊 总耗时: {total_time:.2f}s,RTF: {rtf:.2f}")
 
298
 
299
  # 更新统计信息
300
  self.generation_stats['total_generations'] += 1
 
308
  'total_tokens': total_tokens,
309
  'audio_duration': audio_duration,
310
  'rtf': rtf,
 
311
  'semantic_speed': semantic_speed,
312
  'decode_speed': decode_speed,
313
  'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
 
316
 
317
  return wav_data, self.generation_stats['last_generation']
318
 
319
+ def _generate_tokens(self, text: str, age: str, gender: str, emotion: str, pitch: str, speed: str, generated_key: str = None) -> Tuple[List[int], List[int], float, float, float, float]:
320
  """
321
  生成global tokens和semantic tokens
322
 
 
327
  emotion: 情感参数
328
  pitch: 音高参数
329
  speed: 速度参数
330
+ generated_key: 之前生成的全局token的key
331
  Returns:
332
  Tuple: (global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed)
333
  """
334
  # 编码文本
335
+ logger.info("🔤 编码文本...")
336
  tokens = self.tokenizer.encode(text)
337
+ logger.info(f"✅ 文本编码完成,共 {len(tokens)} 个token")
338
 
339
  # 生成全局token
340
+ logger.info("🌐 生成全局token...")
341
  global_start = time.time()
342
 
343
  # 准备输入tokens
 
348
  # 构建属性tokens - 使用properties_util.py
349
  from properties_util import convert_standard_properties_to_tokens
350
  properties_text = convert_standard_properties_to_tokens(age, gender, emotion, pitch, speed)
351
+ logger.info(f'🔤 属性文本: {properties_text}')
352
  properties_tokens = self.tokenizer.encode(properties_text, add_special_tokens=False)
353
  properties_tokens = [i + 8196 + 4096 for i in properties_tokens]
354
 
 
356
  text_tokens = [i + 8196 + 4096 for i in tokens]
357
 
358
  # 组合所有tokens
359
+ if generated_key is None or generated_key not in generated_global_tokens:
360
+ all_idx = properties_tokens + [TTS_TAG_2] + text_tokens + [TTS_TAG_0]
361
+ else:
362
+ logger.info(f"🎯 使用之前生成的全局token: {generated_key}")
363
+ previous_global_tokens = generated_global_tokens[generated_key]
364
+ global_tokens = previous_global_tokens.copy()
365
+ global_time = 0
366
+ global_speed = 0
367
+ logger.info(f"🎯 使用之前生成的全局token: {previous_global_tokens}")
368
+ previous_global_tokens = [int(i) + 8196 for i in previous_global_tokens]
369
+ logger.info(f"🎯 偏移后的全局token: {previous_global_tokens}")
370
+ all_idx = properties_tokens + [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + previous_global_tokens
371
+ logger.info(f'🔢 属性token: {properties_tokens}')
372
+ logger.info(f'🔢 文本token: {text_tokens}')
373
+ logger.info(f'🎯 组合后的tokens: {all_idx}')
374
 
375
  # Prefill阶段
376
+ logger.info("💎 开始Prefill阶段...")
377
+ session = self.runtime.create_inference_session([all_idx],token_chunk_size=512)
378
+ step_count = 0
379
+ start = time.time()
380
+ while not session.is_complete():
381
+ step_count += 1
382
+ output = session.step()
383
+ if not output.batches[0].is_empty():
384
+ logits = output.batches[0].data
385
+ break
386
+
387
+ prefill_time = time.time() - start
388
+ logger.info(f"✅ Prefill完成,耗时 {step_count} 步")
389
+ logger.info(f"✅ Prefill完成,logits长度: {len(logits)}")
390
+ logger.info(f"✅ Prefill完成,耗时 {prefill_time:.2f}s {len(all_idx)/prefill_time:.1f} tokens/s")
391
 
392
  # 生成全局token - 按照tts_gui_simple.py的逻辑
393
+ if generated_key is None or generated_key not in generated_global_tokens:
394
+ logger.info("🌍 开始生成全局token...")
395
+ global_tokens_size = 32
396
+ global_tokens = []
397
+
398
+ for i in range(global_tokens_size):
399
+ # logits中采样token
400
+ sampled_id = sample_logits(logits[0:4096], temperature=1.0, top_p=0.95, top_k=20)
401
+ global_tokens.append(sampled_id)
402
+ # 预测下一个token
403
+ sampled_id += 8196
404
+ logits = self.runtime.predict_next(sampled_id)
405
+
406
+ global_time = time.time() - global_start
407
+ global_speed = global_tokens_size / global_time if global_time > 0 else 0
408
+ logger.info(f"✅ 全局token生成完成,共 {len(global_tokens)} 个token,耗时 {global_time:.2f}s,速度 {global_speed:.1f} tokens/s")
409
+ logger.info(f'🎯 生成的全局token: {global_tokens}')
410
+ prefix = f"{age}_{gender}_{pitch}_{emotion}_{speed}"
411
+ key = f"{prefix}_0"
412
+ if key in generated_global_tokens:
413
+ #found the latest index of the same key
414
+ latest_index = max([int(k.split('_')[-1]) for k in generated_global_tokens.keys() if k.startswith(prefix)])
415
+ key = f"{prefix}_{latest_index + 1}"
416
+ generated_global_tokens[key] = global_tokens
417
+ logger.info(f'🎯 生成的全局token: {generated_global_tokens[key]}, 下次可以调用generated_global_tokens[{key}]')
418
+
419
+
420
 
421
  # 生成语义token
422
+ logger.info("🧠 生成语义token...")
423
  semantic_start = time.time()
424
 
425
  # 按照tts_gui_simple.py的逻辑生成语义token
 
429
  for i in range(2048): # 最大生成2048个token
430
  sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
431
  if sampled_id == 8192: # 遇到结束标记
432
+ logger.info(f"🛑 语义token生成结束,遇到结束标记,共生成 {len(semantic_tokens)} 个token")
433
  break
434
  semantic_tokens.append(sampled_id)
435
  x = self.runtime.predict_next(sampled_id)
436
 
437
  semantic_time = time.time() - semantic_start
438
  semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
439
+ logger.info(f"✅ 语义token生成完成,共 {len(semantic_tokens)} 个token,耗时 {semantic_time:.2f}s,速度 {semantic_speed:.1f} tokens/s")
440
 
441
  return global_tokens, semantic_tokens, global_time, global_speed, semantic_time, semantic_speed
442
+
443
+ def _generate_tokens_with_global_tokens(self, text: str, global_tokens: List[int]) -> Tuple[List[int], List[int], float, float, float, float]:
444
+ """
445
+ 使用 global tokens 生成语义token
446
+ """
447
+ # 编码文本
448
+ logger.info("🔤 编码文本...")
449
+ text_tokens = self.tokenizer.encode(text, add_special_tokens=False)
450
+ text_tokens = [i + 8196 + 4096 for i in text_tokens]
451
+ logger.info(f"✅ 文本编码完成,共 {len(text_tokens)} 个token")
452
+ global_tokens = [int(i) + 8196 for i in global_tokens]
453
+ logger.info(f'🎯 参考音频 global_tokens: {global_tokens}')
454
+ start = time.time()
455
+
456
+ # 准备输入tokens
457
+ TTS_TAG_0 = 8193
458
+ TTS_TAG_1 = 8194
459
+ TTS_TAG_2 = 8195
460
+
461
+ # 组合所有tokens
462
+ all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1]
463
+ logger.info(f'🎯 组合后的tokens: {all_idx}')
464
+
465
+ # Prefill阶段
466
+ logger.info("💎 开始Prefill阶段...")
467
+ session = self.runtime.create_inference_session([all_idx],token_chunk_size=512)
468
+ step_count = 0
469
+ while not session.is_complete():
470
+ step_count += 1
471
+ output = session.step()
472
+ if not output.batches[0].is_empty():
473
+ logits = output.batches[0].data[0]
474
+ break
475
+ logger.info(f"✅ Prefill完成,耗时 {step_count} 步")
476
+ logger.info(f"✅ Prefill完成,速度 {step_count/output.time:.1f} tokens/s")
477
+ logger.info(f"✅ Prefill完成,logits长度: {len(logits)}")
478
+ prefill_time = time.time() - start
479
+ prefill_speed = len(all_idx) / prefill_time if prefill_time > 0 else 0
480
+ logger.info(f"✅ Prefill完成,耗时 {prefill_time:.2f}s,速度 {prefill_speed:.1f} tokens/s")
481
+
482
+ # 生成语义token
483
+ logger.info("🧠 生成语义token...")
484
+ semantic_start = time.time()
485
 
486
+ # 从当前logits开始生成语义token
487
+ x = logits
488
+ semantic_tokens = []
489
+
490
+ for i in range(2048): # 最大生成2048个token
491
+ sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
492
+ if sampled_id == 8192: # 遇到结束标记
493
+ logger.info(f"🛑 语义token生成结束,遇到结束标记,共生成 {len(semantic_tokens)} 个token")
494
+ break
495
+ semantic_tokens.append(sampled_id)
496
+ x = self.runtime.predict_next(sampled_id)
497
+
498
+ semantic_time = time.time() - semantic_start
499
+ semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
500
+ logger.info(f"✅ 语义token生成完成,共 {len(semantic_tokens)} 个token,耗时 {semantic_time:.2f}s,速度 {semantic_speed:.1f} tokens/s")
501
+
502
+ return global_tokens, semantic_tokens, prefill_time, prefill_speed, semantic_time, semantic_speed
503
+
504
  def _generate_tokens_zeroshot(self, text: str, ref_audio_path: str, prompt_text: str = "希望你以后能够做的,比我还好呦!") -> Tuple[List[int], List[int], float, float, float, float]:
505
  """
506
  使用 zero shot 方式生成global tokens和semantic tokens
 
517
  raise RuntimeError("RefAudioUtilities 未初始化,无法使用 zero shot 模式")
518
 
519
  # 编码文本
520
+ logger.info("🔤 编码文本...")
521
  text_tokens = self.tokenizer.encode(prompt_text + text, add_special_tokens=False)
522
  text_tokens = [i + 8196 + 4096 for i in text_tokens]
523
+ logger.info(f"✅ 文本编码完成,共 {len(text_tokens)} 个token")
524
 
525
  # 从参考音频获取 global tokens 和 semantic tokens
526
+ logger.info("🎵 处理参考音频...")
527
  global_tokens, prompt_semantic_tokens = self.ref_audio_utilities.tokenize(ref_audio_path)
528
+ logger.info(f"✅ 参考音频处理完成")
529
 
530
  # 直接使用flatten()展平数组并转换为Python一维数组
531
  global_tokens = [int(i) + 8196 for i in global_tokens.flatten()]
532
  prompt_semantic_tokens = [int(i) for i in prompt_semantic_tokens.flatten()]
533
 
534
+ logger.info(f'🎯 参考音频 global_tokens: {global_tokens}')
535
+ logger.info(f'🎯 参考音频 semantic_tokens: {prompt_semantic_tokens}')
536
+
537
 
 
 
 
538
 
539
  # 准备输入tokens
540
  TTS_TAG_0 = 8193
 
543
 
544
  # 组合所有tokens
545
  all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1] + prompt_semantic_tokens
546
+ logger.info(f'🎯 组合后的tokens: {all_idx}')
547
 
548
  # Prefill阶段
549
+ logger.info("💎 开始Prefill阶段...")
550
+ session = self.runtime.create_inference_session([all_idx],token_chunk_size=512)
551
+ step_count = 0
552
+ start = time.time()
553
+ while not session.is_complete():
554
+ step_count += 1
555
+ output = session.step()
556
+ if not output.batches[0].is_empty():
557
+ logits = output.batches[0].data
558
+ break
559
+ prefill_time = time.time() - start
560
+ logger.info(f"✅ Prefill完成,logits长度: {len(logits)}")
561
+ logger.info(f"✅ Prefill完成,耗时 {step_count} 步")
562
+ logger.info(f"✅ Prefill完成,耗时 {prefill_time:.2f}s {len(all_idx)/prefill_time:.1f} tokens/s")
563
+
564
 
565
  # 生成语义token
566
+ logger.info("🧠 生成语义token...")
567
  semantic_start = time.time()
568
 
569
  # 从当前logits开始生成语义token
 
573
  for i in range(2048): # 最大生成2048个token
574
  sampled_id = sample_logits(x[0:8193], temperature=1.0, top_p=0.95, top_k=80)
575
  if sampled_id == 8192: # 遇到结束标记
576
+ logger.info(f"🛑 语义token生成结束,遇到结束标记,共生成 {len(semantic_tokens)} 个token")
577
  break
578
  semantic_tokens.append(sampled_id)
579
  x = self.runtime.predict_next(sampled_id)
580
 
581
  semantic_time = time.time() - semantic_start
582
  semantic_speed = len(semantic_tokens) / semantic_time if semantic_time > 0 else 0
583
+ logger.info(f"✅ 语义token生成完成,共 {len(semantic_tokens)} 个token,耗时 {semantic_time:.2f}s,速度 {semantic_speed:.1f} tokens/s")
584
 
585
  global_tokens = [i - 8196 for i in global_tokens]
586
+ return global_tokens, semantic_tokens, semantic_time, semantic_speed
587
+
588
+ def _decode_audio(self, global_tokens: List[int], semantic_tokens: List[int]) -> Tuple[np.ndarray, float, float, float]:
589
+ """
590
+ 解码音频的核心函数
591
+
592
+ Args:
593
+ global_tokens: 全局tokens列表
594
+ semantic_tokens: 语义tokens列表
595
+
596
+ Returns:
597
+ Tuple: (wav_data, audio_duration, decode_time, decode_speed)
598
+ """
599
+ # 开始计时
600
+ decode_start = time.time()
601
+
602
+ # 准备输入数据
603
+ logger.info("🔧 准备解码器输入数据...")
604
+ global_tokens_array = np.array(global_tokens, dtype=np.int64).reshape(1, 1, -1)
605
+ semantic_tokens_array = np.array(semantic_tokens, dtype=np.int64).reshape(1, -1)
606
+ logger.info(f'🎯 生成的全局token: {global_tokens}')
607
+ logger.info(f'🎯 生成的语义token: {semantic_tokens}')
608
+ logger.info(f'📊 解码器输入形状: global_tokens={global_tokens_array.shape}, semantic_tokens={semantic_tokens_array.shape}')
609
+
610
+ # 使用ONNX解码器生成音频
611
+ logger.info("🎵 开始ONNX解码器推理...")
612
+ outputs = self.ort_session.run(None, {
613
+ "global_tokens": global_tokens_array,
614
+ "semantic_tokens": semantic_tokens_array
615
+ })
616
+ wav_data = outputs[0].reshape(-1)
617
+ decode_time = time.time() - decode_start
618
+
619
+ # 计算音频时长和解码速度
620
+ audio_duration = len(wav_data) / 16000 # 采样率16kHz
621
+ decode_speed = len(semantic_tokens) / decode_time if decode_time > 0 else 0
622
+
623
+ logger.info(f"✅ 音频解码完成,时长 {audio_duration:.2f}s,耗时 {decode_time:.2f}s,速度 {decode_speed:.1f} tokens/s")
624
+
625
+ return wav_data, audio_duration, decode_time, decode_speed
626
+
627
+ def _save_audio(self, wav_data: np.ndarray, output_path: str, sample_rate: int = 16000) -> bool:
628
+ """
629
+ 保存音频文件
630
+
631
+ Args:
632
+ wav_data: 音频数据
633
+ output_path: 输出文件路径
634
+ sample_rate: 采样率,默认16kHz
635
+
636
+ Returns:
637
+ bool: 保存是否成功
638
+ """
639
+ try:
640
+ sf.write(output_path, wav_data, sample_rate)
641
+ logger.info(f"💾 音频保存成功: {output_path}")
642
+ return True
643
+ except Exception as e:
644
+ logger.error(f"❌ 音频保存失败: {e}")
645
+ return False
646
 
647
  def display_stats(stats: Dict[str, Any]):
648
  """显示生成统计信息"""
649
+ logger.info("\n" + "="*60)
650
+ logger.info("📊 生成统计信息")
651
+ logger.info("="*60)
652
 
653
  if stats['text']:
654
+ logger.info(f"🎯 生成参数: {stats['params']}")
655
+ logger.info(f"📝 文本: {stats['text']}")
656
+ logger.info(f"⏱️ 总耗时: {stats['total_time']:.2f}s")
657
+ logger.info(f"🎵 音频时长: {stats['audio_duration']:.2f}s")
658
+ logger.info(f"📈 RTF: {stats['rtf']:.2f}")
659
+ logger.info(f"🔢 总token数: {stats['total_tokens']}")
660
+ logger.info(f"🧠 语义token速度: {stats['semantic_speed']:.1f} tokens/s")
661
+ logger.info(f"🎵 解码速度: {stats['decode_speed']:.1f} tokens/s")
662
+ logger.info(f"🕐 时间: {stats['timestamp']}")
 
663
  if stats['output_path']:
664
+ logger.info(f"💾 保存路径: {stats['output_path']}")
665
  else:
666
+ logger.info("暂无生成记录")
667
 
668
+ logger.info("="*60)
669
 
670
  def interactive_parameter_selection(generator: TTSGenerator):
671
  """交互式参数选择界面"""
672
+ logger.info("\n🎮 进入交互式配置界面")
673
+ logger.info("💡 使用方向键选择,回车确认,Ctrl+C退出")
674
 
675
  while True:
676
  try:
677
+ logger.info("\n" + "="*60)
678
+ logger.info("🎵 RWKV TTS 参数配置")
679
+ logger.info("="*60)
680
 
681
  # 选择生成模式
682
  generation_mode = questionary.select(
 
759
  output_path = get_unique_filename(output_dir, text)
760
 
761
  # 保存音频
762
+ if generator._save_audio(wav_data, output_path, 16000):
763
+ stats['output_path'] = output_path
764
+ else:
765
+ logger.warning("⚠️ 音频保存失败,但生成统计已更新")
766
 
767
+ logger.info(f"✅ 音频生成成功,保存至: {output_path}")
768
  stats['生成参数'] = f'参考音频={ref_audio_path}, 提示文本={prompt_text}'
769
  # 显示统计信息
770
  display_stats(stats)
771
 
772
  except Exception as e:
773
+ logger.error(f"❌ 生成失败: {e}")
774
  import traceback
775
  traceback.print_exc()
776
  else:
 
824
 
825
  if speed is None:
826
  break
827
+ prefix = f"{age}_{gender}"
828
+ list_of_generated_keys = []
829
+ for generated_key in generated_global_tokens.keys():
830
+ if generated_key.startswith(prefix):
831
+ list_of_generated_keys.append(generated_key)
832
+ if len(list_of_generated_keys) > 0:
833
+ list_of_generated_keys.append("None")
834
+ generated_key = questionary.select(
835
+ "🎯 是否使用之前生成的全局token?",
836
+ choices=list_of_generated_keys,
837
+ default="None"
838
+ ).ask()
839
+ else:
840
+ generated_key = None
841
 
842
 
843
  # 确认生成
 
845
  f"🚀 确认生成音频?\n"
846
  f"文本: {text}\n"
847
  f"参数: 年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}\n"
848
+ f"输出目录: {output_dir}\n"
849
+ f"是否使用之前生成的全局token: {generated_key is not None}",
850
  default=True
851
  ).ask()
852
 
 
860
  'emotion': emotion,
861
  'pitch': pitch,
862
  'speed': speed,
863
+ 'output_dir': output_dir,
864
+ 'generated_key': generated_key
865
  }
866
 
867
  # 生成音频
 
872
  output_path = get_unique_filename(output_dir, text)
873
 
874
  # 保存音频
875
+ if generator._save_audio(wav_data, output_path, 16000):
876
+ stats['output_path'] = output_path
877
+ else:
878
+ logger.warning("⚠️ 音频保存失败,但生成统计已更新")
879
 
880
+ logger.info(f"✅ 音频生成成功,保存至: {output_path}")
881
  stats['生成参数'] = f'年龄={age}, 性别={gender}, 情感={emotion}, 音高={pitch}, 速度={speed}'
882
  # 显示统计信息
883
  display_stats(stats)
884
 
885
  except Exception as e:
886
+ logger.error(f"❌ 生成失败: {e}")
887
  import traceback
888
  traceback.print_exc()
889
 
 
897
  break
898
 
899
  except KeyboardInterrupt:
900
+ logger.info("\n👋 用户中断,退出程序")
901
  break
902
  except Exception as e:
903
+ logger.error(f"❌ 发生错误: {e}")
904
  import traceback
905
  traceback.print_exc()
906
  break
907
 
908
+ logger.info("👋 感谢使用 RWKV TTS!")
909
 
910
  @click.command()
911
  @click.option('--model_path', required=True, help='RWKV模型路径')
912
  def main(model_path):
913
  """RWKV TTS 主程序"""
914
+ logger.info("🚀 欢迎使用 RWKV TTS 交互式音频生成工具!")
915
 
916
  # 检查模型文件
917
  if not os.path.exists(model_path):
918
+ logger.error(f"❌ 错误: 模型路径不存在: {model_path}")
919
  return
920
 
921
  # 自动构建解码器路径
922
  decoder_path = os.path.join(model_path, "BiCodecDetokenize.onnx")
923
+ logger.info(f"🔍 自动设置解码器路径: {decoder_path}")
924
 
925
  # 检查模型目录中的文件
926
+ logger.info(f"🔍 检查模型目录: {model_path}")
927
  try:
928
  model_files = os.listdir(model_path)
929
+ logger.info(f"📁 模型目录中的文件:")
930
  for file in model_files:
931
  file_path = os.path.join(model_path, file)
932
  if os.path.isfile(file_path):
933
  size = os.path.getsize(file_path)
934
+ logger.info(f" 📄 {file} ({size:,} bytes)")
935
  else:
936
+ logger.info(f" 📁 {file}/")
937
  except Exception as e:
938
+ logger.warning(f"⚠️ 无法列出模型目录内容: {e}")
939
 
940
  if not os.path.exists(decoder_path):
941
+ logger.error(f"❌ 错误: 解码器路径不存在: {decoder_path}")
942
  return
943
 
944
  # 选择设备
945
+ logger.info("\n💎 选择设备 💎")
946
  try:
947
  devices = webrwkv_py.get_available_adapters_py()
948
+ except Exception as e:
949
+ logger.error(f"❌ 无法获取可用设备列表: {e}")
950
+ return
 
 
 
 
951
 
952
  for i, device in enumerate(devices):
953
  print(f"{i}: {device}")
 
956
  try:
957
  device_idx = int(device_choice)
958
  if device_idx < 0 or device_idx >= len(devices):
959
+ logger.error("❌ 无效的设备选择")
960
  return
961
  device = devices[device_idx]
962
+ logger.info(f"✅ 选择设备: {device}")
963
  except ValueError:
964
+ logger.error("❌ 无效的设备选择")
965
  return
966
 
967
  # 加载模型
968
+ logger.info("\n💎 加载模型 💎")
969
  try:
970
  # 尝试多种可能的模型文件名
971
  possible_model_files = [
 
977
  test_path = os.path.join(model_path, model_file)
978
  if os.path.exists(test_path):
979
  webrwkv_model_path = test_path
980
+ logger.info(f"✅ 找到模型文件: {model_file}")
981
  break
982
 
983
  if webrwkv_model_path is None:
984
+ logger.error(f"❌ 未找到模型文件")
985
+ logger.info(f"💡 请检查模型目录 {model_path} 中是否包含以下文件之一:")
986
  for model_file in possible_model_files:
987
+ logger.info(f" - {model_file}")
988
  return
989
 
990
+ logger.info(f"🔍 尝试加载模型文件: {webrwkv_model_path}")
991
 
992
  # 尝试新的API
993
  model = webrwkv_py.Model(webrwkv_model_path, 'fp32', device_idx)
994
+ logger.info(f"✅ 模型加载成功: {webrwkv_model_path}")
995
  except Exception as e:
996
+ logger.error(f"❌ 模型加载失败: {e}")
997
+ logger.info(f"💡 请检查:")
998
+ logger.info(f" 1. 模型文件路径是否正确: {webrwkv_model_path}")
999
+ logger.info(f" 2. 模型文件是否完整")
1000
+ logger.info(f" 3. 设备索引是否正确: {device_idx}")
1001
+ logger.info(f" 4. 模型文件格式是否支持")
1002
  return
1003
 
1004
  # 创建runtime
1005
+ logger.info("\n💎 创建 runtime 💎")
1006
  try:
1007
  runtime = model.create_thread_runtime()
1008
+ logger.info("✅ runtime 创建成功")
1009
  except Exception as e:
1010
+ logger.error(f"❌ runtime 创建失败: {e}")
1011
  return
1012
 
1013
  # 加载tokenizer
1014
+ logger.info("\n💎 加载 tokenizer 💎")
1015
  try:
1016
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
1017
+ logger.info(f"✅ tokenizer 加载成功: {model_path}")
1018
  except Exception as e:
1019
+ logger.error(f"❌ tokenizer 加载失败: {e}")
1020
+ logger.info(f"💡 请检查模型目录 {model_path} 中是否包含正确的tokenizer文件")
1021
  return
1022
 
1023
  # 创建TTS生成器
1024
  generator = TTSGenerator(runtime, tokenizer, decoder_path, device, model_path)
1025
 
1026
  # 启动交互式界面
1027
+ logger.info("\n🎯 启动交互式配置界面...")
1028
  interactive_parameter_selection(generator)
1029
 
1030
  if __name__ == "__main__":
1031
  main()
1032
+