littlebird13 committed on
Commit
a91a338
·
verified ·
1 Parent(s): d55a774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -139
app.py CHANGED
@@ -1,125 +1,47 @@
1
  # coding=utf-8
2
  # Qwen3-TTS Gradio Demo for HuggingFace Spaces with Zero GPU
3
  # Supports: Voice Design, Voice Clone (Base), TTS (CustomVoice)
4
- # Optimized: Load models on demand to save GPU memory
5
-
6
  import os
7
  import spaces
8
  import gradio as gr
9
  import numpy as np
10
  import torch
11
- from huggingface_hub import snapshot_download, login
12
- from qwen_tts import Qwen3TTSModel
13
 
 
14
  HF_TOKEN = os.environ.get('HF_TOKEN')
15
  login(token=HF_TOKEN)
16
 
 
 
 
17
  # Model size options
18
  MODEL_SIZES = ["0.6B", "1.7B"]
19
 
20
- # Speaker and language choices for CustomVoice model
21
- SPEAKERS = [
22
- "Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
23
- ]
24
- LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
25
-
26
 
27
  def get_model_path(model_type: str, model_size: str) -> str:
28
  """Get model path based on type and size."""
29
  return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
30
 
31
 
32
- # ============================================================================
33
- # ON-DEMAND MODEL LOADING - Load models only when needed
34
- # ============================================================================
35
-
36
- # Global model cache
37
- _model_cache = {}
38
- _current_model_key = None
39
-
40
-
41
- def print_gpu_memory(msg=""):
42
- """Print current GPU memory usage."""
43
- if torch.cuda.is_available():
44
- allocated = torch.cuda.memory_allocated() / 1e9
45
- reserved = torch.cuda.memory_reserved() / 1e9
46
- print(f"[GPU Memory {msg}] Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
47
-
48
-
49
- def clear_model_cache():
50
- """Clear all cached models and free GPU memory."""
51
- global _model_cache, _current_model_key
52
-
53
- for key in list(_model_cache.keys()):
54
- print(f"Unloading model: {key}")
55
- del _model_cache[key]
56
-
57
- _model_cache = {}
58
- _current_model_key = None
59
-
60
- if torch.cuda.is_available():
61
- torch.cuda.empty_cache()
62
- torch.cuda.synchronize()
63
-
64
- import gc
65
- gc.collect()
66
-
67
- print_gpu_memory("after clearing cache")
68
-
69
-
70
  def get_model(model_type: str, model_size: str):
71
- """
72
- Load model on demand with caching.
73
- Only keeps one model in memory at a time to save GPU memory.
74
-
75
- Args:
76
- model_type: "VoiceDesign", "Base", or "CustomVoice"
77
- model_size: "0.6B" or "1.7B"
78
-
79
- Returns:
80
- Loaded model
81
- """
82
- global _model_cache, _current_model_key
83
-
84
- cache_key = f"{model_type}_{model_size}"
85
-
86
- # If requested model is already loaded, return it
87
- if cache_key in _model_cache:
88
- print(f"Using cached model: {cache_key}")
89
- return _model_cache[cache_key]
90
-
91
- # Clear existing models to free GPU memory
92
- if _model_cache:
93
- print(f"Switching from {_current_model_key} to {cache_key}")
94
- clear_model_cache()
95
-
96
- print_gpu_memory("before loading")
97
-
98
- # Load the requested model
99
- print(f"Loading {model_type} {model_size} model...")
100
- model_path = get_model_path(model_type, model_size)
101
-
102
- model = Qwen3TTSModel.from_pretrained(
103
- model_path,
104
- device_map="cuda",
105
- dtype=torch.bfloat16,
106
- token=HF_TOKEN,
107
- # Note: Remove flash-attn if you encounter compatibility issues
108
- # attn_implementation="kernels-community/flash-attn3",
109
- )
110
-
111
- _model_cache[cache_key] = model
112
- _current_model_key = cache_key
113
-
114
- print_gpu_memory("after loading")
115
- print(f"Model {cache_key} loaded successfully!")
116
-
117
- return model
118
-
119
 
120
- # ============================================================================
121
- # Audio utility functions
122
- # ============================================================================
123
 
124
  def _normalize_audio(wav, eps=1e-12, clip=True):
125
  """Normalize audio to float32 in [-1, 1] range."""
@@ -167,12 +89,15 @@ def _audio_to_tuple(audio):
167
  return None
168
 
169
 
170
- # ============================================================================
171
- # Generation functions
172
- # ============================================================================
 
 
 
173
 
174
- @spaces.GPU(duration=120) # Increased duration for model loading + generation
175
- def generate_voice_design(text, language, voice_description, progress=gr.Progress(track_tqdm=True)):
176
  """Generate speech using Voice Design model (1.7B only)."""
177
  if not text or not text.strip():
178
  return None, "Error: Text is required."
@@ -180,25 +105,21 @@ def generate_voice_design(text, language, voice_description, progress=gr.Progres
180
  return None, "Error: Voice description is required."
181
 
182
  try:
183
- # Load model on demand
184
- model = get_model("VoiceDesign", "1.7B")
185
-
186
- wavs, sr = model.generate_voice_design(
187
  text=text.strip(),
188
  language=language,
189
  instruct=voice_description.strip(),
190
  non_streaming_mode=True,
191
- max_new_tokens=1024,
192
  )
193
  return (sr, wavs[0]), "Voice design generation completed successfully!"
194
  except Exception as e:
195
- import traceback
196
- traceback.print_exc()
197
  return None, f"Error: {type(e).__name__}: {e}"
198
 
199
 
200
- @spaces.GPU(duration=120) # Increased duration for model loading + generation
201
- def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size, progress=gr.Progress(track_tqdm=True)):
202
  """Generate speech using Base (Voice Clone) model."""
203
  if not target_text or not target_text.strip():
204
  return None, "Error: Target text is required."
@@ -211,26 +132,22 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
211
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
212
 
213
  try:
214
- # Load model on demand
215
- model = get_model("Base", model_size)
216
-
217
- wavs, sr = model.generate_voice_clone(
218
  text=target_text.strip(),
219
  language=language,
220
  ref_audio=audio_tuple,
221
  ref_text=ref_text.strip() if ref_text else None,
222
  x_vector_only_mode=use_xvector_only,
223
- max_new_tokens=1024,
224
  )
225
  return (sr, wavs[0]), "Voice clone generation completed successfully!"
226
  except Exception as e:
227
- import traceback
228
- traceback.print_exc()
229
  return None, f"Error: {type(e).__name__}: {e}"
230
 
231
 
232
- @spaces.GPU(duration=120) # Increased duration for model loading + generation
233
- def generate_custom_voice(text, language, speaker, instruct, model_size, progress=gr.Progress(track_tqdm=True)):
234
  """Generate speech using CustomVoice model."""
235
  if not text or not text.strip():
236
  return None, "Error: Text is required."
@@ -238,28 +155,21 @@ def generate_custom_voice(text, language, speaker, instruct, model_size, progres
238
  return None, "Error: Speaker is required."
239
 
240
  try:
241
- # Load model on demand
242
- model = get_model("CustomVoice", model_size)
243
-
244
- wavs, sr = model.generate_custom_voice(
245
  text=text.strip(),
246
  language=language,
247
  speaker=speaker.lower().replace(" ", "_"),
248
  instruct=instruct.strip() if instruct else None,
249
  non_streaming_mode=True,
250
- max_new_tokens=1024,
251
  )
252
  return (sr, wavs[0]), "Generation completed successfully!"
253
  except Exception as e:
254
- import traceback
255
- traceback.print_exc()
256
  return None, f"Error: {type(e).__name__}: {e}"
257
 
258
 
259
- # ============================================================================
260
- # Gradio UI
261
- # ============================================================================
262
-
263
  def build_ui():
264
  theme = gr.themes.Soft(
265
  font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
@@ -278,10 +188,7 @@ A unified Text-to-Speech demo featuring three powerful modes:
278
  - **Voice Design**: Create custom voices using natural language descriptions
279
  - **Voice Clone (Base)**: Clone any voice from a reference audio
280
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
281
-
282
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
283
-
284
- > **Note**: Models are loaded on-demand to optimize GPU memory usage. First generation in each mode may take longer due to model loading.
285
  """
286
  )
287
 
@@ -424,9 +331,6 @@ Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team
424
  ---
425
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
426
  For longer texts, please split them into smaller segments.
427
-
428
- **Memory Optimization**: Models are loaded on-demand and only one model is kept in memory at a time.
429
- Switching between different models/sizes will automatically unload the previous model.
430
  """
431
  )
432
 
 
1
  # coding=utf-8
2
  # Qwen3-TTS Gradio Demo for HuggingFace Spaces with Zero GPU
3
  # Supports: Voice Design, Voice Clone (Base), TTS (CustomVoice)
4
+ #import subprocess
5
+ #subprocess.run('pip install flash-attn==2.7.4.post1', shell=True)
6
  import os
7
  import spaces
8
  import gradio as gr
9
  import numpy as np
10
  import torch
11
+ from huggingface_hub import snapshot_download
 
12
 
13
+ from huggingface_hub import login
14
  HF_TOKEN = os.environ.get('HF_TOKEN')
15
  login(token=HF_TOKEN)
16
 
17
+ # Global model holders - keyed by (model_type, model_size)
18
+ loaded_models = {}
19
+
20
  # Model size options
21
  MODEL_SIZES = ["0.6B", "1.7B"]
22
 
 
 
 
 
 
 
23
 
24
  def get_model_path(model_type: str, model_size: str) -> str:
25
  """Get model path based on type and size."""
26
  return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_model(model_type: str, model_size: str):
30
+ """Get or load a model by type and size."""
31
+ global loaded_models
32
+ key = (model_type, model_size)
33
+ if key not in loaded_models:
34
+ from qwen_tts import Qwen3TTSModel
35
+ model_path = get_model_path(model_type, model_size)
36
+ loaded_models[key] = Qwen3TTSModel.from_pretrained(
37
+ model_path,
38
+ device_map="cuda",
39
+ dtype=torch.bfloat16,
40
+ token=HF_TOKEN,
41
+ # attn_implementation="flash_attention_2",
42
+ )
43
+ return loaded_models[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
45
 
46
  def _normalize_audio(wav, eps=1e-12, clip=True):
47
  """Normalize audio to float32 in [-1, 1] range."""
 
89
  return None
90
 
91
 
92
+ # Speaker and language choices for CustomVoice model
93
+ SPEAKERS = [
94
+ "Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
95
+ ]
96
+ LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
97
+
98
 
99
+ @spaces.GPU(duration=60)
100
+ def generate_voice_design(text, language, voice_description):
101
  """Generate speech using Voice Design model (1.7B only)."""
102
  if not text or not text.strip():
103
  return None, "Error: Text is required."
 
105
  return None, "Error: Voice description is required."
106
 
107
  try:
108
+ tts = get_model("VoiceDesign", "1.7B")
109
+ wavs, sr = tts.generate_voice_design(
 
 
110
  text=text.strip(),
111
  language=language,
112
  instruct=voice_description.strip(),
113
  non_streaming_mode=True,
114
+ max_new_tokens=2048,
115
  )
116
  return (sr, wavs[0]), "Voice design generation completed successfully!"
117
  except Exception as e:
 
 
118
  return None, f"Error: {type(e).__name__}: {e}"
119
 
120
 
121
+ @spaces.GPU(duration=60)
122
+ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
123
  """Generate speech using Base (Voice Clone) model."""
124
  if not target_text or not target_text.strip():
125
  return None, "Error: Target text is required."
 
132
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
133
 
134
  try:
135
+ tts = get_model("Base", model_size)
136
+ wavs, sr = tts.generate_voice_clone(
 
 
137
  text=target_text.strip(),
138
  language=language,
139
  ref_audio=audio_tuple,
140
  ref_text=ref_text.strip() if ref_text else None,
141
  x_vector_only_mode=use_xvector_only,
142
+ max_new_tokens=2048,
143
  )
144
  return (sr, wavs[0]), "Voice clone generation completed successfully!"
145
  except Exception as e:
 
 
146
  return None, f"Error: {type(e).__name__}: {e}"
147
 
148
 
149
+ @spaces.GPU(duration=60)
150
+ def generate_custom_voice(text, language, speaker, instruct, model_size):
151
  """Generate speech using CustomVoice model."""
152
  if not text or not text.strip():
153
  return None, "Error: Text is required."
 
155
  return None, "Error: Speaker is required."
156
 
157
  try:
158
+ tts = get_model("CustomVoice", model_size)
159
+ wavs, sr = tts.generate_custom_voice(
 
 
160
  text=text.strip(),
161
  language=language,
162
  speaker=speaker.lower().replace(" ", "_"),
163
  instruct=instruct.strip() if instruct else None,
164
  non_streaming_mode=True,
165
+ max_new_tokens=2048,
166
  )
167
  return (sr, wavs[0]), "Generation completed successfully!"
168
  except Exception as e:
 
 
169
  return None, f"Error: {type(e).__name__}: {e}"
170
 
171
 
172
+ # Build Gradio UI
 
 
 
173
  def build_ui():
174
  theme = gr.themes.Soft(
175
  font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
 
188
  - **Voice Design**: Create custom voices using natural language descriptions
189
  - **Voice Clone (Base)**: Clone any voice from a reference audio
190
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
 
191
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
 
 
192
  """
193
  )
194
 
 
331
  ---
332
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
333
  For longer texts, please split them into smaller segments.
 
 
 
334
  """
335
  )
336