smartwang committed on
Commit
fb11c6d
·
1 Parent(s): 7e217de
Files changed (1) hide show
  1. app.py +40 -36
app.py CHANGED
@@ -53,37 +53,37 @@ def get_model_path(model_type: str, model_size: str) -> str:
53
  # ============================================================================
54
  logger.info("正在加载所有模型到 CUDA...")
55
 
56
- # Voice Design model (1.7B only)
57
- logger.info("正在加载 VoiceDesign 1.7B 模型...")
58
- voice_design_model = Qwen3TTSModel.from_pretrained(
59
- get_model_path("VoiceDesign", "1.7B"),
60
- device_map="cuda",
61
- dtype=torch.bfloat16,
62
- token=HF_TOKEN,
63
- attn_implementation="kernels-community/flash-attn3",
64
- )
65
 
66
- # Base (Voice Clone) models - both sizes
67
- logger.info("正在加载 Base 0.6B 模型...")
68
- base_model_0_6b = Qwen3TTSModel.from_pretrained(
69
- get_model_path("Base", "0.6B"),
70
- device_map="cuda",
71
- dtype=torch.bfloat16,
72
- token=HF_TOKEN,
73
- attn_implementation="kernels-community/flash-attn3",
74
- )
75
 
76
- # @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
- # def load_model(model_type, model_size):
78
- # logger.info(f"正在按需加载 {model_type} {model_size} 模型...")
79
- # path = get_model_path(model_type, model_size)
80
- # return Qwen3TTSModel.from_pretrained(
81
- # path,
82
- # device_map="cuda", # 注意:在 ZeroGPU 环境下,这行只有在被装饰的函数内执行才有效
83
- # dtype=torch.bfloat16,
84
- # token=HF_TOKEN,
85
- # attn_implementation="kernels-community/flash-attn3"
86
- # )
87
 
88
  # logger.info("正在加载 Base 1.7B 模型...")
89
  # base_model_1_7b = Qwen3TTSModel.from_pretrained(
@@ -209,7 +209,8 @@ def split_text(text, max_len=30):
209
  @spaces.GPU
210
  def infer_voice_design(part, language, voice_description):
211
  """Single segment inference for Voice Design."""
212
- # voice_design_model = voice_design_model
 
213
  wavs, sr = voice_design_model.generate_voice_design(
214
  text=part,
215
  language=language,
@@ -221,10 +222,15 @@ def infer_voice_design(part, language, voice_description):
221
 
222
 
223
  @spaces.GPU
224
- def infer_voice_clone(model_size, part, language, voice_clone_prompt):
225
  """Single segment inference for Voice Clone."""
226
  # tts = BASE_MODELS[model_size]
227
- tts = base_model_0_6b
 
 
 
 
 
228
  wavs, sr = tts.generate_voice_clone(
229
  text=part,
230
  language=language,
@@ -233,10 +239,9 @@ def infer_voice_clone(model_size, part, language, voice_clone_prompt):
233
  )
234
  return wavs[0], sr
235
 
236
- @spaces.GPU
237
  def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
238
  logger.info("正在提取参考音频特征(仅执行一次)...")
239
- tts = base_model_0_6b
240
  voice_clone_prompt = tts.create_voice_clone_prompt(
241
  ref_audio=audio_tuple,
242
  ref_text=ref_text.strip() if ref_text else None,
@@ -301,7 +306,6 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
301
 
302
  logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
303
  try:
304
- voice_clone_prompt = extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only)
305
  text_parts = split_text(target_text.strip())
306
  logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
307
  all_wavs = []
@@ -309,7 +313,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
309
 
310
  for i, part in enumerate(progress.tqdm(text_parts, desc="正在生成分段")):
311
  logger.info(f"正在处理第 {i+1}/{len(text_parts)} 段文本...")
312
- wav, current_sr = infer_voice_clone(model_size, part, language, voice_clone_prompt)
313
  all_wavs.append(wav)
314
  sr = current_sr
315
 
 
53
  # ============================================================================
54
  logger.info("正在加载所有模型到 CUDA...")
55
 
56
+ # # Voice Design model (1.7B only)
57
+ # logger.info("正在加载 VoiceDesign 1.7B 模型...")
58
+ # voice_design_model = Qwen3TTSModel.from_pretrained(
59
+ # get_model_path("VoiceDesign", "1.7B"),
60
+ # device_map="cuda",
61
+ # dtype=torch.bfloat16,
62
+ # token=HF_TOKEN,
63
+ # attn_implementation="kernels-community/flash-attn3",
64
+ # )
65
 
66
+ # # Base (Voice Clone) models - both sizes
67
+ # logger.info("正在加载 Base 0.6B 模型...")
68
+ # base_model_0_6b = Qwen3TTSModel.from_pretrained(
69
+ # get_model_path("Base", "0.6B"),
70
+ # device_map="cuda",
71
+ # dtype=torch.bfloat16,
72
+ # token=HF_TOKEN,
73
+ # attn_implementation="kernels-community/flash-attn3",
74
+ # )
75
 
76
+ @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
+ def load_model(model_type, model_size):
78
+ logger.info(f"正在按需加载 {model_type} {model_size} 模型...")
79
+ path = get_model_path(model_type, model_size)
80
+ return Qwen3TTSModel.from_pretrained(
81
+ path,
82
+ device_map="cuda", # 注意:在 ZeroGPU 环境下,这行只有在被装饰的函数内执行才有效
83
+ dtype=torch.bfloat16,
84
+ token=HF_TOKEN,
85
+ attn_implementation="kernels-community/flash-attn3"
86
+ )
87
 
88
  # logger.info("正在加载 Base 1.7B 模型...")
89
  # base_model_1_7b = Qwen3TTSModel.from_pretrained(
 
209
  @spaces.GPU
210
  def infer_voice_design(part, language, voice_description):
211
  """Single segment inference for Voice Design."""
212
+ voice_design_model = load_model("VoiceDesign","1.7B")
213
+
214
  wavs, sr = voice_design_model.generate_voice_design(
215
  text=part,
216
  language=language,
 
222
 
223
 
224
  @spaces.GPU
225
+ def infer_voice_clone( part, language,audio_tuple,ref_text,use_xvector_only):
226
  """Single segment inference for Voice Clone."""
227
  # tts = BASE_MODELS[model_size]
228
+ tts = load_model("Base", "0.6B")
229
+ voice_clone_prompt = tts.create_voice_clone_prompt(
230
+ ref_audio=audio_tuple,
231
+ ref_text=ref_text.strip() if ref_text else None,
232
+ x_vector_only_mode=use_xvector_only
233
+ )
234
  wavs, sr = tts.generate_voice_clone(
235
  text=part,
236
  language=language,
 
239
  )
240
  return wavs[0], sr
241
 
 
242
  def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
243
  logger.info("正在提取参考音频特征(仅执行一次)...")
244
+ tts = load_model("Base", "0.6B")
245
  voice_clone_prompt = tts.create_voice_clone_prompt(
246
  ref_audio=audio_tuple,
247
  ref_text=ref_text.strip() if ref_text else None,
 
306
 
307
  logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
308
  try:
 
309
  text_parts = split_text(target_text.strip())
310
  logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
311
  all_wavs = []
 
313
 
314
  for i, part in enumerate(progress.tqdm(text_parts, desc="正在生成分段")):
315
  logger.info(f"正在处理第 {i+1}/{len(text_parts)} 段文本...")
316
+ wav, current_sr = infer_voice_clone( part, language,audio_tuple,ref_text,use_xvector_only)
317
  all_wavs.append(wav)
318
  sr = current_sr
319