Spaces:
Running on Zero
Running on Zero
app.py
CHANGED
|
@@ -53,37 +53,37 @@ def get_model_path(model_type: str, model_size: str) -> str:
|
|
| 53 |
# ============================================================================
|
| 54 |
logger.info("正在加载所有模型到 CUDA...")
|
| 55 |
|
| 56 |
-
# Voice Design model (1.7B only)
|
| 57 |
-
logger.info("正在加载 VoiceDesign 1.7B 模型...")
|
| 58 |
-
voice_design_model = Qwen3TTSModel.from_pretrained(
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
)
|
| 65 |
|
| 66 |
-
# Base (Voice Clone) models - both sizes
|
| 67 |
-
logger.info("正在加载 Base 0.6B 模型...")
|
| 68 |
-
base_model_0_6b = Qwen3TTSModel.from_pretrained(
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
)
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
|
| 88 |
# logger.info("正在加载 Base 1.7B 模型...")
|
| 89 |
# base_model_1_7b = Qwen3TTSModel.from_pretrained(
|
|
@@ -209,7 +209,8 @@ def split_text(text, max_len=30):
|
|
| 209 |
@spaces.GPU
|
| 210 |
def infer_voice_design(part, language, voice_description):
|
| 211 |
"""Single segment inference for Voice Design."""
|
| 212 |
-
|
|
|
|
| 213 |
wavs, sr = voice_design_model.generate_voice_design(
|
| 214 |
text=part,
|
| 215 |
language=language,
|
|
@@ -221,10 +222,15 @@ def infer_voice_design(part, language, voice_description):
|
|
| 221 |
|
| 222 |
|
| 223 |
@spaces.GPU
|
| 224 |
-
def infer_voice_clone(
|
| 225 |
"""Single segment inference for Voice Clone."""
|
| 226 |
# tts = BASE_MODELS[model_size]
|
| 227 |
-
tts =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
wavs, sr = tts.generate_voice_clone(
|
| 229 |
text=part,
|
| 230 |
language=language,
|
|
@@ -233,10 +239,9 @@ def infer_voice_clone(model_size, part, language, voice_clone_prompt):
|
|
| 233 |
)
|
| 234 |
return wavs[0], sr
|
| 235 |
|
| 236 |
-
@spaces.GPU
|
| 237 |
def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
|
| 238 |
logger.info("正在提取参考音频特征(仅执行一次)...")
|
| 239 |
-
tts =
|
| 240 |
voice_clone_prompt = tts.create_voice_clone_prompt(
|
| 241 |
ref_audio=audio_tuple,
|
| 242 |
ref_text=ref_text.strip() if ref_text else None,
|
|
@@ -301,7 +306,6 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
|
|
| 301 |
|
| 302 |
logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
|
| 303 |
try:
|
| 304 |
-
voice_clone_prompt = extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only)
|
| 305 |
text_parts = split_text(target_text.strip())
|
| 306 |
logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
|
| 307 |
all_wavs = []
|
|
@@ -309,7 +313,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
|
|
| 309 |
|
| 310 |
for i, part in enumerate(progress.tqdm(text_parts, desc="正在生成分段")):
|
| 311 |
logger.info(f"正在处理第 {i+1}/{len(text_parts)} 段文本...")
|
| 312 |
-
wav, current_sr = infer_voice_clone(
|
| 313 |
all_wavs.append(wav)
|
| 314 |
sr = current_sr
|
| 315 |
|
|
|
|
| 53 |
# ============================================================================
|
| 54 |
logger.info("正在加载所有模型到 CUDA...")
|
| 55 |
|
| 56 |
+
# # Voice Design model (1.7B only)
|
| 57 |
+
# logger.info("正在加载 VoiceDesign 1.7B 模型...")
|
| 58 |
+
# voice_design_model = Qwen3TTSModel.from_pretrained(
|
| 59 |
+
# get_model_path("VoiceDesign", "1.7B"),
|
| 60 |
+
# device_map="cuda",
|
| 61 |
+
# dtype=torch.bfloat16,
|
| 62 |
+
# token=HF_TOKEN,
|
| 63 |
+
# attn_implementation="kernels-community/flash-attn3",
|
| 64 |
+
# )
|
| 65 |
|
| 66 |
+
# # Base (Voice Clone) models - both sizes
|
| 67 |
+
# logger.info("正在加载 Base 0.6B 模型...")
|
| 68 |
+
# base_model_0_6b = Qwen3TTSModel.from_pretrained(
|
| 69 |
+
# get_model_path("Base", "0.6B"),
|
| 70 |
+
# device_map="cuda",
|
| 71 |
+
# dtype=torch.bfloat16,
|
| 72 |
+
# token=HF_TOKEN,
|
| 73 |
+
# attn_implementation="kernels-community/flash-attn3",
|
| 74 |
+
# )
|
| 75 |
|
| 76 |
+
@functools.lru_cache(maxsize=1)  # keep only the model currently in use resident, to save VRAM
def load_model(model_type, model_size):
    """Lazily load a Qwen3-TTS model onto CUDA, caching the most recent one.

    Args:
        model_type: Model family passed to ``get_model_path``,
            e.g. ``"VoiceDesign"`` or ``"Base"``.
        model_size: Size tag, e.g. ``"1.7B"`` or ``"0.6B"``.

    Returns:
        The loaded ``Qwen3TTSModel``. Because ``maxsize=1``, requesting a
        different (type, size) pair evicts the previously cached model.
    """
    # Lazy %-style args: formatting only happens if INFO is enabled.
    logger.info("正在按需加载 %s %s 模型...", model_type, model_size)
    path = get_model_path(model_type, model_size)
    return Qwen3TTSModel.from_pretrained(
        path,
        # NOTE(review): under ZeroGPU this placement only takes effect when
        # executed inside a @spaces.GPU-decorated function — confirm callers.
        device_map="cuda",
        dtype=torch.bfloat16,
        token=HF_TOKEN,
        attn_implementation="kernels-community/flash-attn3",
    )
|
| 87 |
|
| 88 |
# logger.info("正在加载 Base 1.7B 模型...")
|
| 89 |
# base_model_1_7b = Qwen3TTSModel.from_pretrained(
|
|
|
|
| 209 |
@spaces.GPU
|
| 210 |
def infer_voice_design(part, language, voice_description):
|
| 211 |
"""Single segment inference for Voice Design."""
|
| 212 |
+
voice_design_model = load_model("VoiceDesign","1.7B")
|
| 213 |
+
|
| 214 |
wavs, sr = voice_design_model.generate_voice_design(
|
| 215 |
text=part,
|
| 216 |
language=language,
|
|
|
|
| 222 |
|
| 223 |
|
| 224 |
@spaces.GPU
|
| 225 |
+
def infer_voice_clone( part, language,audio_tuple,ref_text,use_xvector_only):
|
| 226 |
"""Single segment inference for Voice Clone."""
|
| 227 |
# tts = BASE_MODELS[model_size]
|
| 228 |
+
tts = load_model("Base", "0.6B")
|
| 229 |
+
voice_clone_prompt = tts.create_voice_clone_prompt(
|
| 230 |
+
ref_audio=audio_tuple,
|
| 231 |
+
ref_text=ref_text.strip() if ref_text else None,
|
| 232 |
+
x_vector_only_mode=use_xvector_only
|
| 233 |
+
)
|
| 234 |
wavs, sr = tts.generate_voice_clone(
|
| 235 |
text=part,
|
| 236 |
language=language,
|
|
|
|
| 239 |
)
|
| 240 |
return wavs[0], sr
|
| 241 |
|
|
|
|
| 242 |
def extract_voice_clone_prompt(audio_tuple,ref_text,use_xvector_only):
|
| 243 |
logger.info("正在提取参考音频特征(仅执行一次)...")
|
| 244 |
+
tts = load_model("Base", "0.6B")
|
| 245 |
voice_clone_prompt = tts.create_voice_clone_prompt(
|
| 246 |
ref_audio=audio_tuple,
|
| 247 |
ref_text=ref_text.strip() if ref_text else None,
|
|
|
|
| 306 |
|
| 307 |
logger.info(f"开始 Voice Clone 生成任务。模型大小: {model_size}, 语言: {language}, 目标文本长度: {len(target_text)}, 仅使用 x-vector: {use_xvector_only}")
|
| 308 |
try:
|
|
|
|
| 309 |
text_parts = split_text(target_text.strip())
|
| 310 |
logger.info(f"目标文本已切分为 {len(text_parts)} 段。")
|
| 311 |
all_wavs = []
|
|
|
|
| 313 |
|
| 314 |
for i, part in enumerate(progress.tqdm(text_parts, desc="正在生成分段")):
|
| 315 |
logger.info(f"正在处理第 {i+1}/{len(text_parts)} 段文本...")
|
| 316 |
+
wav, current_sr = infer_voice_clone( part, language,audio_tuple,ref_text,use_xvector_only)
|
| 317 |
all_wavs.append(wav)
|
| 318 |
sr = current_sr
|
| 319 |
|