smartwang committed on
Commit
08f4100
·
1 Parent(s): 3309e4b
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -54,24 +54,24 @@ def get_model_path(model_type: str, model_size: str) -> str:
54
  logger.info("正在加载所有模型到 CUDA...")
55
 
56
  # Voice Design model (1.7B only)
57
- # logger.info("正在加载 VoiceDesign 1.7B 模型...")
58
- # voice_design_model = Qwen3TTSModel.from_pretrained(
59
- # get_model_path("VoiceDesign", "1.7B"),
60
- # device_map="cuda",
61
- # dtype=torch.bfloat16,
62
- # token=HF_TOKEN,
63
- # attn_implementation="kernels-community/flash-attn3",
64
- # )
65
 
66
  # Base (Voice Clone) models - both sizes
67
- # logger.info("正在加载 Base 0.6B 模型...")
68
- # base_model_0_6b = Qwen3TTSModel.from_pretrained(
69
- # get_model_path("Base", "0.6B"),
70
- # device_map="cuda",
71
- # dtype=torch.bfloat16,
72
- # token=HF_TOKEN,
73
- # attn_implementation="kernels-community/flash-attn3",
74
- # )
75
 
76
  @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
  def load_model(model_type, model_size):
@@ -233,7 +233,17 @@ def infer_voice_clone(model_size, part, language, voice_clone_prompt):
233
  )
234
  return wavs[0], sr
235
 
236
-
 
 
 
 
 
 
 
 
 
 
237
  # @spaces.GPU(duration=60)
238
  # def infer_custom_voice(model_size, part, language, speaker, instruct):
239
  # """Single segment inference for Custom Voice."""
@@ -291,16 +301,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
291
 
292
  logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
293
  try:
294
- # 优化:在循环外提取参考音频特征,避免重复处理
295
- logger.info("正在提取参考音频特征(仅执行一次)...")
296
- tts = load_model("Base", "0.6B")
297
- voice_clone_prompt = tts.create_voice_clone_prompt(
298
- ref_audio=audio_tuple,
299
- ref_text=ref_text.strip() if ref_text else None,
300
- x_vector_only_mode=use_xvector_only
301
- )
302
- logger.info("参考音频特征提取完成。")
303
-
304
  text_parts = split_text(target_text.strip())
305
  logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
306
  all_wavs = []
 
54
  logger.info("正在加载所有模型到 CUDA...")
55
 
56
  # Voice Design model (1.7B only)
57
+ logger.info("正在加载 VoiceDesign 1.7B 模型...")
58
+ voice_design_model = Qwen3TTSModel.from_pretrained(
59
+ get_model_path("VoiceDesign", "1.7B"),
60
+ device_map="cuda",
61
+ dtype=torch.bfloat16,
62
+ token=HF_TOKEN,
63
+ attn_implementation="kernels-community/flash-attn3",
64
+ )
65
 
66
  # Base (Voice Clone) models - both sizes
67
+ logger.info("正在加载 Base 0.6B 模型...")
68
+ base_model_0_6b = Qwen3TTSModel.from_pretrained(
69
+ get_model_path("Base", "0.6B"),
70
+ device_map="cuda",
71
+ dtype=torch.bfloat16,
72
+ token=HF_TOKEN,
73
+ attn_implementation="kernels-community/flash-attn3",
74
+ )
75
 
76
  @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
  def load_model(model_type, model_size):
 
233
  )
234
  return wavs[0], sr
235
 
236
@spaces.GPU
def extract_voice_clone_prompt(audio_tuple, ref_text, use_xvector_only):
    """Build the voice-clone prompt from the reference audio.

    Obtains the Base 0.6B model through ``load_model`` and calls its
    ``create_voice_clone_prompt`` method, so a caller can compute the prompt
    once and reuse the returned value.

    Args:
        audio_tuple: Reference audio, forwarded unchanged as ``ref_audio``.
            NOTE(review): presumably a (waveform, sample_rate) pair; confirm
            against the caller.
        ref_text: Transcript of the reference audio. It is stripped of
            surrounding whitespace when truthy; otherwise ``None`` is passed.
        use_xvector_only: Forwarded unchanged as ``x_vector_only_mode``.

    Returns:
        Whatever ``create_voice_clone_prompt`` returns.
    """
    logger.info("正在提取参考音频特征(仅执行一次)...")
    model = load_model("Base", "0.6B")

    # A falsy transcript (None or "") is passed on as None rather than "".
    if ref_text:
        cleaned_text = ref_text.strip()
    else:
        cleaned_text = None

    prompt = model.create_voice_clone_prompt(
        ref_audio=audio_tuple,
        ref_text=cleaned_text,
        x_vector_only_mode=use_xvector_only,
    )
    logger.info("参考音频特征提取完成。")
    return prompt
247
  # @spaces.GPU(duration=60)
248
  # def infer_custom_voice(model_size, part, language, speaker, instruct):
249
  # """Single segment inference for Custom Voice."""
 
301
 
302
  logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
303
  try:
304
+ voice_clone_prompt = extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only)
 
 
 
 
 
 
 
 
 
305
  text_parts = split_text(target_text.strip())
306
  logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
307
  all_wavs = []