Spaces:
Running on Zero
Running on Zero
- app.py +7 -0
- qwen_tts/inference/qwen3_tts_model.py +11 -3
app.py
CHANGED
|
@@ -251,6 +251,13 @@ def infer_voice_clone_from_prompt(part, language, prompt_file_path):
|
|
| 251 |
# 尝试作为单个对象处理
|
| 252 |
voice_clone_prompt = loaded_data
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
logger.info("音频特征文件加载成功。")
|
| 255 |
|
| 256 |
tts = load_model("Base", "0.6B")
|
|
|
|
| 251 |
# 尝试作为单个对象处理
|
| 252 |
voice_clone_prompt = loaded_data
|
| 253 |
|
| 254 |
+
# 维度校正:确保 ref_code 是 2D 的 (Time, Q)
|
| 255 |
+
if isinstance(voice_clone_prompt, list):
|
| 256 |
+
for item in voice_clone_prompt:
|
| 257 |
+
if item.ref_code is not None and item.ref_code.ndim == 3:
|
| 258 |
+
# [1, T, Q] -> [T, Q]
|
| 259 |
+
item.ref_code = item.ref_code.squeeze(0)
|
| 260 |
+
|
| 261 |
logger.info("音频特征文件加载成功。")
|
| 262 |
|
| 263 |
tts = load_model("Base", "0.6B")
|
qwen_tts/inference/qwen3_tts_model.py
CHANGED
|
@@ -628,10 +628,18 @@ class Qwen3TTSModel:
|
|
| 628 |
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
| 629 |
if ref_code_list is not None and ref_code_list[i] is not None:
|
| 630 |
# 在 12Hz 模型中,Token 长度与时间成正比 (12 tokens/sec)
|
| 631 |
-
#
|
| 632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
if codes.shape[0] > ref_len:
|
| 634 |
-
logger.info(f"检测到生成的 Token 序列包含引导部分,正在切除
|
| 635 |
processed_codes.append(codes[ref_len:])
|
| 636 |
else:
|
| 637 |
processed_codes.append(codes)
|
|
|
|
| 628 |
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
| 629 |
if ref_code_list is not None and ref_code_list[i] is not None:
|
| 630 |
# 在 12Hz 模型中,Token 长度与时间成正比 (12 tokens/sec)
|
| 631 |
+
# 核心模型生成的 talker_codes 往往包含了与 Prompt 长度相当的引导部分
|
| 632 |
+
# 确保 ref_len 始终对应时间维度的长度
|
| 633 |
+
ref_item = ref_code_list[i]
|
| 634 |
+
if ref_item.ndim == 3: # [Batch, Time, Q]
|
| 635 |
+
ref_len = int(ref_item.shape[1])
|
| 636 |
+
elif ref_item.ndim == 2: # [Time, Q] 或 [Batch, Time]
|
| 637 |
+
ref_len = int(ref_item.shape[0])
|
| 638 |
+
else:
|
| 639 |
+
ref_len = int(ref_item.shape[0])
|
| 640 |
+
|
| 641 |
if codes.shape[0] > ref_len:
|
| 642 |
+
logger.info(f"检测到生成的 Token 序列包含引导部分 (长度 {ref_len}),正在执行切除")
|
| 643 |
processed_codes.append(codes[ref_len:])
|
| 644 |
else:
|
| 645 |
processed_codes.append(codes)
|