littlebird13 commited on
Commit
3df69b4
·
verified ·
1 Parent(s): 1cab9e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -75
app.py CHANGED
@@ -1,8 +1,8 @@
1
  # coding=utf-8
2
  # Qwen3-TTS Gradio Demo for HuggingFace Spaces with Zero GPU
3
  # Supports: Voice Design, Voice Clone (Base), TTS (CustomVoice)
4
- #import subprocess
5
- #subprocess.run('pip install flash-attn==2.7.4.post1', shell=True)
6
  import os
7
  import spaces
8
  import gradio as gr
@@ -30,74 +30,97 @@ def get_model_path(model_type: str, model_size: str) -> str:
30
 
31
 
32
  # ============================================================================
33
- # GLOBAL MODEL LOADING - Load all models at startup
34
  # ============================================================================
35
- print("Loading all models to CUDA...")
36
-
37
- # Voice Design model (1.7B only)
38
- print("Loading VoiceDesign 1.7B model...")
39
- voice_design_model = Qwen3TTSModel.from_pretrained(
40
- get_model_path("VoiceDesign", "1.7B"),
41
- device_map="cuda",
42
- dtype=torch.bfloat16,
43
- token=HF_TOKEN,
44
- #attn_implementation="kernels-community/flash-attn3",
45
- )
46
-
47
- # Base (Voice Clone) models - both sizes
48
- print("Loading Base 0.6B model...")
49
- base_model_0_6b = Qwen3TTSModel.from_pretrained(
50
- get_model_path("Base", "0.6B"),
51
- device_map="cuda",
52
- dtype=torch.bfloat16,
53
- token=HF_TOKEN,
54
- #attn_implementation="kernels-community/flash-attn3",
55
- )
56
-
57
- print("Loading Base 1.7B model...")
58
- base_model_1_7b = Qwen3TTSModel.from_pretrained(
59
- get_model_path("Base", "1.7B"),
60
- device_map="cuda",
61
- dtype=torch.bfloat16,
62
- token=HF_TOKEN,
63
- #attn_implementation="kernels-community/flash-attn3",
64
- )
65
-
66
- # CustomVoice models - both sizes
67
- print("Loading CustomVoice 0.6B model...")
68
- custom_voice_model_0_6b = Qwen3TTSModel.from_pretrained(
69
- get_model_path("CustomVoice", "0.6B"),
70
- device_map="cuda",
71
- dtype=torch.bfloat16,
72
- token=HF_TOKEN,
73
- attn_implementation="kernels-community/flash-attn3",
74
- )
75
-
76
- print("Loading CustomVoice 1.7B model...")
77
- custom_voice_model_1_7b = Qwen3TTSModel.from_pretrained(
78
- get_model_path("CustomVoice", "1.7B"),
79
- device_map="cuda",
80
- dtype=torch.bfloat16,
81
- token=HF_TOKEN,
82
- attn_implementation="kernels-community/flash-attn3",
83
- )
84
-
85
- print("All models loaded successfully!")
86
-
87
- # Model lookup dictionaries for easy access
88
- BASE_MODELS = {
89
- "0.6B": base_model_0_6b,
90
- "1.7B": base_model_1_7b,
91
- }
92
-
93
- CUSTOM_VOICE_MODELS = {
94
- "0.6B": custom_voice_model_0_6b,
95
- "1.7B": custom_voice_model_1_7b,
96
- }
97
 
98
- # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
 
 
 
 
101
  def _normalize_audio(wav, eps=1e-12, clip=True):
102
  """Normalize audio to float32 in [-1, 1] range."""
103
  x = np.asarray(wav)
@@ -144,7 +167,11 @@ def _audio_to_tuple(audio):
144
  return None
145
 
146
 
147
- @spaces.GPU(duration=60)
 
 
 
 
148
  def generate_voice_design(text, language, voice_description, progress=gr.Progress(track_tqdm=True)):
149
  """Generate speech using Voice Design model (1.7B only)."""
150
  if not text or not text.strip():
@@ -153,7 +180,10 @@ def generate_voice_design(text, language, voice_description, progress=gr.Progres
153
  return None, "Error: Voice description is required."
154
 
155
  try:
156
- wavs, sr = voice_design_model.generate_voice_design(
 
 
 
157
  text=text.strip(),
158
  language=language,
159
  instruct=voice_description.strip(),
@@ -162,10 +192,12 @@ def generate_voice_design(text, language, voice_description, progress=gr.Progres
162
  )
163
  return (sr, wavs[0]), "Voice design generation completed successfully!"
164
  except Exception as e:
 
 
165
  return None, f"Error: {type(e).__name__}: {e}"
166
 
167
 
168
- @spaces.GPU(duration=60)
169
  def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size, progress=gr.Progress(track_tqdm=True)):
170
  """Generate speech using Base (Voice Clone) model."""
171
  if not target_text or not target_text.strip():
@@ -179,8 +211,10 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
179
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
180
 
181
  try:
182
- tts = BASE_MODELS[model_size]
183
- wavs, sr = tts.generate_voice_clone(
 
 
184
  text=target_text.strip(),
185
  language=language,
186
  ref_audio=audio_tuple,
@@ -190,10 +224,12 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
190
  )
191
  return (sr, wavs[0]), "Voice clone generation completed successfully!"
192
  except Exception as e:
 
 
193
  return None, f"Error: {type(e).__name__}: {e}"
194
 
195
 
196
- @spaces.GPU(duration=60)
197
  def generate_custom_voice(text, language, speaker, instruct, model_size, progress=gr.Progress(track_tqdm=True)):
198
  """Generate speech using CustomVoice model."""
199
  if not text or not text.strip():
@@ -202,8 +238,10 @@ def generate_custom_voice(text, language, speaker, instruct, model_size, progres
202
  return None, "Error: Speaker is required."
203
 
204
  try:
205
- tts = CUSTOM_VOICE_MODELS[model_size]
206
- wavs, sr = tts.generate_custom_voice(
 
 
207
  text=text.strip(),
208
  language=language,
209
  speaker=speaker.lower().replace(" ", "_"),
@@ -213,10 +251,15 @@ def generate_custom_voice(text, language, speaker, instruct, model_size, progres
213
  )
214
  return (sr, wavs[0]), "Generation completed successfully!"
215
  except Exception as e:
 
 
216
  return None, f"Error: {type(e).__name__}: {e}"
217
 
218
 
219
- # Build Gradio UI
 
 
 
220
  def build_ui():
221
  theme = gr.themes.Soft(
222
  font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
@@ -235,7 +278,10 @@ A unified Text-to-Speech demo featuring three powerful modes:
235
  - **Voice Design**: Create custom voices using natural language descriptions
236
  - **Voice Clone (Base)**: Clone any voice from a reference audio
237
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
 
238
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
 
 
239
  """
240
  )
241
 
@@ -378,6 +424,9 @@ Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team
378
  ---
379
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
380
  For longer texts, please split them into smaller segments.
 
 
 
381
  """
382
  )
383
 
 
1
  # coding=utf-8
2
  # Qwen3-TTS Gradio Demo for HuggingFace Spaces with Zero GPU
3
  # Supports: Voice Design, Voice Clone (Base), TTS (CustomVoice)
4
+ # Optimized: Load models on demand to save GPU memory
5
+
6
  import os
7
  import spaces
8
  import gradio as gr
 
30
 
31
 
32
  # ============================================================================
33
+ # ON-DEMAND MODEL LOADING - Load models only when needed
34
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
# Global model cache: maps "<model_type>_<model_size>" keys to loaded
# Qwen3TTSModel instances. clear_model_cache()/get_model() keep at most one
# entry resident at a time to stay within GPU memory limits.
_model_cache = {}
# Cache key of the model currently loaded, or None when nothing is resident.
_current_model_key = None
39
+
40
+
41
def print_gpu_memory(msg=""):
    """Log allocated/reserved CUDA memory; silently a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    # Convert bytes to gigabytes for readable log output.
    alloc_gb = torch.cuda.memory_allocated() / 1e9
    resv_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[GPU Memory {msg}] Allocated: {alloc_gb:.2f}GB, Reserved: {resv_gb:.2f}GB")
47
+
48
+
49
def clear_model_cache():
    """Unload every cached model and release the GPU memory it held.

    Resets the module-level ``_model_cache`` and ``_current_model_key``
    globals, then forces Python garbage collection before asking CUDA to
    release its caching-allocator blocks.
    """
    global _model_cache, _current_model_key

    for key in list(_model_cache.keys()):
        print(f"Unloading model: {key}")
        del _model_cache[key]

    _model_cache = {}
    _current_model_key = None

    # Fix: collect BEFORE empty_cache(). The dropped model objects are only
    # destroyed once the garbage collector runs; empty_cache() can return
    # their CUDA blocks to the driver only after that happens. The original
    # order (empty_cache first, gc.collect last) left the just-deleted
    # models' memory still reserved.
    import gc
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    print_gpu_memory("after clearing cache")
68
+
69
+
70
def get_model(model_type: str, model_size: str):
    """Return the requested TTS model, loading it on demand.

    At most one model stays resident: asking for a different
    (type, size) pair evicts whatever is currently cached first.

    Args:
        model_type: "VoiceDesign", "Base", or "CustomVoice"
        model_size: "0.6B" or "1.7B"

    Returns:
        The loaded model instance.
    """
    global _model_cache, _current_model_key

    cache_key = f"{model_type}_{model_size}"

    # Fast path: the requested model is already resident.
    cached = _model_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached model: {cache_key}")
        return cached

    # A different model occupies the GPU -- evict it to make room.
    if _model_cache:
        print(f"Switching from {_current_model_key} to {cache_key}")
        clear_model_cache()

    print_gpu_memory("before loading")

    print(f"Loading {model_type} {model_size} model...")
    model = Qwen3TTSModel.from_pretrained(
        get_model_path(model_type, model_size),
        device_map="cuda",
        dtype=torch.bfloat16,
        token=HF_TOKEN,
        # Note: Remove flash-attn if you encounter compatibility issues
        # attn_implementation="kernels-community/flash-attn3",
    )

    _model_cache[cache_key] = model
    _current_model_key = cache_key

    print_gpu_memory("after loading")
    print(f"Model {cache_key} loaded successfully!")

    return model
118
 
119
 
120
+ # ============================================================================
121
+ # Audio utility functions
122
+ # ============================================================================
123
+
124
  def _normalize_audio(wav, eps=1e-12, clip=True):
125
  """Normalize audio to float32 in [-1, 1] range."""
126
  x = np.asarray(wav)
 
167
  return None
168
 
169
 
170
+ # ============================================================================
171
+ # Generation functions
172
+ # ============================================================================
173
+
174
+ @spaces.GPU(duration=120) # Increased duration for model loading + generation
175
  def generate_voice_design(text, language, voice_description, progress=gr.Progress(track_tqdm=True)):
176
  """Generate speech using Voice Design model (1.7B only)."""
177
  if not text or not text.strip():
 
180
  return None, "Error: Voice description is required."
181
 
182
  try:
183
+ # Load model on demand
184
+ model = get_model("VoiceDesign", "1.7B")
185
+
186
+ wavs, sr = model.generate_voice_design(
187
  text=text.strip(),
188
  language=language,
189
  instruct=voice_description.strip(),
 
192
  )
193
  return (sr, wavs[0]), "Voice design generation completed successfully!"
194
  except Exception as e:
195
+ import traceback
196
+ traceback.print_exc()
197
  return None, f"Error: {type(e).__name__}: {e}"
198
 
199
 
200
+ @spaces.GPU(duration=120) # Increased duration for model loading + generation
201
  def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size, progress=gr.Progress(track_tqdm=True)):
202
  """Generate speech using Base (Voice Clone) model."""
203
  if not target_text or not target_text.strip():
 
211
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
212
 
213
  try:
214
+ # Load model on demand
215
+ model = get_model("Base", model_size)
216
+
217
+ wavs, sr = model.generate_voice_clone(
218
  text=target_text.strip(),
219
  language=language,
220
  ref_audio=audio_tuple,
 
224
  )
225
  return (sr, wavs[0]), "Voice clone generation completed successfully!"
226
  except Exception as e:
227
+ import traceback
228
+ traceback.print_exc()
229
  return None, f"Error: {type(e).__name__}: {e}"
230
 
231
 
232
+ @spaces.GPU(duration=120) # Increased duration for model loading + generation
233
  def generate_custom_voice(text, language, speaker, instruct, model_size, progress=gr.Progress(track_tqdm=True)):
234
  """Generate speech using CustomVoice model."""
235
  if not text or not text.strip():
 
238
  return None, "Error: Speaker is required."
239
 
240
  try:
241
+ # Load model on demand
242
+ model = get_model("CustomVoice", model_size)
243
+
244
+ wavs, sr = model.generate_custom_voice(
245
  text=text.strip(),
246
  language=language,
247
  speaker=speaker.lower().replace(" ", "_"),
 
251
  )
252
  return (sr, wavs[0]), "Generation completed successfully!"
253
  except Exception as e:
254
+ import traceback
255
+ traceback.print_exc()
256
  return None, f"Error: {type(e).__name__}: {e}"
257
 
258
 
259
+ # ============================================================================
260
+ # Gradio UI
261
+ # ============================================================================
262
+
263
  def build_ui():
264
  theme = gr.themes.Soft(
265
  font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
 
278
  - **Voice Design**: Create custom voices using natural language descriptions
279
  - **Voice Clone (Base)**: Clone any voice from a reference audio
280
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
281
+
282
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
283
+
284
+ > **Note**: Models are loaded on-demand to optimize GPU memory usage. First generation in each mode may take longer due to model loading.
285
  """
286
  )
287
 
 
424
  ---
425
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
426
  For longer texts, please split them into smaller segments.
427
+
428
+ **Memory Optimization**: Models are loaded on-demand and only one model is kept in memory at a time.
429
+ Switching between different models/sizes will automatically unload the previous model.
430
  """
431
  )
432