smartwang commited on
Commit
607ce9b
·
1 Parent(s): 9a86e48
Files changed (2) hide show
  1. app.py +7 -0
  2. qwen_tts/inference/qwen3_tts_model.py +11 -3
app.py CHANGED
@@ -251,6 +251,13 @@ def infer_voice_clone_from_prompt(part, language, prompt_file_path):
251
  # 尝试作为单个对象处理
252
  voice_clone_prompt = loaded_data
253
 
 
 
 
 
 
 
 
254
  logger.info("音频特征文件加载成功。")
255
 
256
  tts = load_model("Base", "0.6B")
 
251
  # 尝试作为单个对象处理
252
  voice_clone_prompt = loaded_data
253
 
254
+ # 维度校正:确保 ref_code 是 2D 的 (Time, Q)
255
+ if isinstance(voice_clone_prompt, list):
256
+ for item in voice_clone_prompt:
257
+ if item.ref_code is not None and item.ref_code.ndim == 3:
258
+ # [1, T, Q] -> [T, Q]
259
+ item.ref_code = item.ref_code.squeeze(0)
260
+
261
  logger.info("音频特征文件加载成功。")
262
 
263
  tts = load_model("Base", "0.6B")
qwen_tts/inference/qwen3_tts_model.py CHANGED
@@ -628,10 +628,18 @@ class Qwen3TTSModel:
628
  ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
629
  if ref_code_list is not None and ref_code_list[i] is not None:
630
  # 在 12Hz 模型中,Token 长度与时间成正比 (12 tokens/sec)
631
- # 经验观察表明生成的结果中包含了一段与 Prompt 长度相当的引导部分
632
- ref_len = int(ref_code_list[i].shape[0])
 
 
 
 
 
 
 
 
633
  if codes.shape[0] > ref_len:
634
- logger.info(f"检测到生成的 Token 序列包含引导部分,正在切除前 {ref_len} 个 Token")
635
  processed_codes.append(codes[ref_len:])
636
  else:
637
  processed_codes.append(codes)
 
628
  ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
629
  if ref_code_list is not None and ref_code_list[i] is not None:
630
  # 在 12Hz 模型中,Token 长度与时间成正比 (12 tokens/sec)
631
+ # 核心模型生成的 talker_codes 往往包含了与 Prompt 长度相当的引导部分
632
+ # 确保 ref_len 始终对应时间维度的长度
633
+ ref_item = ref_code_list[i]
634
+ if ref_item.ndim == 3: # [Batch, Time, Q]
635
+ ref_len = int(ref_item.shape[1])
636
+ elif ref_item.ndim == 2: # [Time, Q] 或 [Batch, Time]
637
+ ref_len = int(ref_item.shape[0])
638
+ else:
639
+ ref_len = int(ref_item.shape[0])
640
+
641
  if codes.shape[0] > ref_len:
642
+ logger.info(f"检测到生成的 Token 序列包含引导部分 (长度 {ref_len}),正在执行切除")
643
  processed_codes.append(codes[ref_len:])
644
  else:
645
  processed_codes.append(codes)