smartwang committed on
Commit
7e217de
·
1 Parent(s): 08f4100
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -73,17 +73,17 @@ base_model_0_6b = Qwen3TTSModel.from_pretrained(
73
  attn_implementation="kernels-community/flash-attn3",
74
  )
75
 
76
- @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
- def load_model(model_type, model_size):
78
- logger.info(f"正在按需加载 {model_type} {model_size} 模型...")
79
- path = get_model_path(model_type, model_size)
80
- return Qwen3TTSModel.from_pretrained(
81
- path,
82
- device_map="cuda", # 注意:在 ZeroGPU 环境下,这行只有在被装饰的函数内执行才有效
83
- dtype=torch.bfloat16,
84
- token=HF_TOKEN,
85
- attn_implementation="kernels-community/flash-attn3"
86
- )
87
 
88
  # logger.info("正在加载 Base 1.7B 模型...")
89
  # base_model_1_7b = Qwen3TTSModel.from_pretrained(
@@ -209,7 +209,7 @@ def split_text(text, max_len=30):
209
  @spaces.GPU
210
  def infer_voice_design(part, language, voice_description):
211
  """Single segment inference for Voice Design."""
212
- voice_design_model = load_model('VoiceDesign','1.7B')
213
  wavs, sr = voice_design_model.generate_voice_design(
214
  text=part,
215
  language=language,
@@ -224,7 +224,7 @@ def infer_voice_design(part, language, voice_description):
224
  def infer_voice_clone(model_size, part, language, voice_clone_prompt):
225
  """Single segment inference for Voice Clone."""
226
  # tts = BASE_MODELS[model_size]
227
- tts = load_model("Base", "0.6B")
228
  wavs, sr = tts.generate_voice_clone(
229
  text=part,
230
  language=language,
@@ -236,7 +236,7 @@ def infer_voice_clone(model_size, part, language, voice_clone_prompt):
236
  @spaces.GPU
237
  def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
238
  logger.info("正在提取参考音频特征(仅执行一次)...")
239
- tts = load_model("Base", "0.6B")
240
  voice_clone_prompt = tts.create_voice_clone_prompt(
241
  ref_audio=audio_tuple,
242
  ref_text=ref_text.strip() if ref_text else None,
 
73
  attn_implementation="kernels-community/flash-attn3",
74
  )
75
 
76
+ # @functools.lru_cache(maxsize=1) # 只缓存当前正在使用的模型,节省显存
77
+ # def load_model(model_type, model_size):
78
+ # logger.info(f"正在按需加载 {model_type} {model_size} 模型...")
79
+ # path = get_model_path(model_type, model_size)
80
+ # return Qwen3TTSModel.from_pretrained(
81
+ # path,
82
+ # device_map="cuda", # 注意:在 ZeroGPU 环境下,这行只有在被装饰的函数内执行才有效
83
+ # dtype=torch.bfloat16,
84
+ # token=HF_TOKEN,
85
+ # attn_implementation="kernels-community/flash-attn3"
86
+ # )
87
 
88
  # logger.info("正在加载 Base 1.7B 模型...")
89
  # base_model_1_7b = Qwen3TTSModel.from_pretrained(
 
209
  @spaces.GPU
210
  def infer_voice_design(part, language, voice_description):
211
  """Single segment inference for Voice Design."""
212
+ # voice_design_model = voice_design_model
213
  wavs, sr = voice_design_model.generate_voice_design(
214
  text=part,
215
  language=language,
 
224
  def infer_voice_clone(model_size, part, language, voice_clone_prompt):
225
  """Single segment inference for Voice Clone."""
226
  # tts = BASE_MODELS[model_size]
227
+ tts = base_model_0_6b
228
  wavs, sr = tts.generate_voice_clone(
229
  text=part,
230
  language=language,
 
236
  @spaces.GPU
237
  def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
238
  logger.info("正在提取参考音频特征(仅执行一次)...")
239
+ tts = base_model_0_6b
240
  voice_clone_prompt = tts.create_voice_clone_prompt(
241
  ref_audio=audio_tuple,
242
  ref_text=ref_text.strip() if ref_text else None,