smartwang committed on
Commit
cc6e7cb
·
1 Parent(s): b0d24b1
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -12,6 +12,7 @@ import numpy as np
12
  import torch
13
  from huggingface_hub import snapshot_download, login
14
  from qwen_tts import Qwen3TTSModel
 
15
  import functools
16
  import uuid
17
  import random
@@ -238,7 +239,18 @@ def infer_voice_clone( part, language,audio_tuple,ref_text,use_xvector_only):
238
  def infer_voice_clone_from_prompt(part, language, prompt_file_path):
239
  """Single segment inference for Voice Clone using pre-extracted prompt."""
240
  logger.info("正在加载音频特征文件...")
241
- voice_clone_prompt = torch.load(prompt_file_path, map_location='cuda', weights_only=False)
 
 
 
 
 
 
 
 
 
 
 
242
  logger.info("音频特征文件加载成功。")
243
 
244
  tts = load_model("Base", "0.6B")
@@ -289,19 +301,30 @@ def extract_voice_clone_prompt(ref_audio,ref_text,use_xvector_only):
289
  except Exception as e:
290
  logger.error(f"Whisper 识别失败: {str(e)}", exc_info=True)
291
 
292
- voice_clone_prompt = tts.create_voice_clone_prompt(
293
  ref_audio=audio_tuple,
294
  ref_text=r_text.strip() if r_text else None,
295
  x_vector_only_mode=uxo
296
  )
297
  logger.info("参考音频特征提取完成。")
298
 
 
 
 
 
 
 
 
 
 
 
 
299
  # 生成唯一的文件名
300
  file_id = str(uuid.uuid4())[:8]
301
  file_path = f"voice_clone_prompt_{file_id}.pt"
302
 
303
  # 保存到文件
304
- torch.save(voice_clone_prompt, file_path)
305
  logger.info(f"voice_clone_prompt 已保存到: {file_path}")
306
 
307
  return file_path
 
12
  import torch
13
  from huggingface_hub import snapshot_download, login
14
  from qwen_tts import Qwen3TTSModel
15
+ from qwen_tts.inference.qwen3_tts_model import VoiceClonePromptItem
16
  import functools
17
  import uuid
18
  import random
 
239
  def infer_voice_clone_from_prompt(part, language, prompt_file_path):
240
  """Single segment inference for Voice Clone using pre-extracted prompt."""
241
  logger.info("正在加载音频特征文件...")
242
+ loaded_data = torch.load(prompt_file_path, map_location='cuda', weights_only=False)
243
+
244
+ # 兼容旧版本直接保存对象的情况
245
+ if isinstance(loaded_data, list) and len(loaded_data) > 0 and isinstance(loaded_data[0], VoiceClonePromptItem):
246
+ voice_clone_prompt = loaded_data
247
+ elif isinstance(loaded_data, list) and len(loaded_data) > 0 and isinstance(loaded_data[0], dict):
248
+ # 从字典列表重建对象
249
+ voice_clone_prompt = [VoiceClonePromptItem(**item) for item in loaded_data]
250
+ else:
251
+ # 尝试作为单个对象处理
252
+ voice_clone_prompt = loaded_data
253
+
254
  logger.info("音频特征文件加载成功。")
255
 
256
  tts = load_model("Base", "0.6B")
 
301
  except Exception as e:
302
  logger.error(f"Whisper 识别失败: {str(e)}", exc_info=True)
303
 
304
+ voice_clone_prompt_items = tts.create_voice_clone_prompt(
305
  ref_audio=audio_tuple,
306
  ref_text=r_text.strip() if r_text else None,
307
  x_vector_only_mode=uxo
308
  )
309
  logger.info("参考音频特征提取完成。")
310
 
311
+ # 转换为字典列表保存,避免对象序列化问题
312
+ prompt_data = []
313
+ for item in voice_clone_prompt_items:
314
+ prompt_data.append({
315
+ "ref_code": item.ref_code,
316
+ "ref_spk_embedding": item.ref_spk_embedding,
317
+ "x_vector_only_mode": item.x_vector_only_mode,
318
+ "icl_mode": item.icl_mode,
319
+ "ref_text": item.ref_text
320
+ })
321
+
322
  # 生成唯一的文件名
323
  file_id = str(uuid.uuid4())[:8]
324
  file_path = f"voice_clone_prompt_{file_id}.pt"
325
 
326
  # 保存到文件
327
+ torch.save(prompt_data, file_path)
328
  logger.info(f"voice_clone_prompt 已保存到: {file_path}")
329
 
330
  return file_path