fosters committed · verified
Commit 5df429e · 1 Parent(s): ce8d2b9

Upload app.py

Files changed (1): app.py (+186 -73)
app.py CHANGED
@@ -1,6 +1,11 @@
 """
-Alternative XTTSv2 loader - loads fine-tuned model from Hugging Face Hub
-Use this if your model is hosted on HF Hub instead of locally in the Space
+Optimized XTTSv2 Hugging Face Space
+- DeepSpeed acceleration
+- FP16 inference
+- torch.compile() optimization
+- Speaker latent caching
+- Streaming inference
+- Memory optimization
 """

 import gradio as gr
@@ -8,8 +13,10 @@ import torch
 import os
 import gc
 import hashlib
+import tempfile
 import numpy as np
-from huggingface_hub import snapshot_download
+from pathlib import Path
+from functools import lru_cache
 from typing import Optional, Tuple
 import logging

@@ -17,42 +24,43 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 # ============== Configuration ==============
-# Change this to your HF Hub model repo
-HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "your-username/your-xtts-finetuned")
-USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "true").lower() == "true"
+MODEL_PATH = os.environ.get("MODEL_PATH", "./model")
+USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "false").lower() == "true"  # Disabled by default for stability
 USE_FP16 = os.environ.get("USE_FP16", "true").lower() == "true"
-USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "true").lower() == "true"
-MAX_CACHE_SIZE = int(os.environ.get("MAX_CACHE_SIZE", "10"))
+USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "false").lower() == "true"  # Disabled by default for stability
+MAX_CACHE_SIZE = int(os.environ.get("MAX_CACHE_SIZE", "10"))  # Max cached speakers
 STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20"))

 # ============== Model Loading ==============
-def download_and_load_model():
-    """Download model from HF Hub and load with optimizations"""
+def load_model():
+    """Load XTTSv2 with all optimizations"""
+    from TTS.api import TTS
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts

-    logger.info(f"Downloading model from {HF_MODEL_REPO}...")
+    logger.info("Loading XTTSv2 model...")

-    # Download model files from HF Hub
-    model_path = snapshot_download(
-        repo_id=HF_MODEL_REPO,
-        allow_patterns=["*.pth", "*.json", "*.txt", "vocab.*"],
-        local_dir="./model",
-        local_dir_use_symlinks=False
-    )
-
-    logger.info(f"Model downloaded to {model_path}")
-
-    config = XttsConfig()
-    config.load_json(os.path.join(model_path, "config.json"))
+    # Check if local model exists, otherwise use default from HF Hub
+    local_config = os.path.join(MODEL_PATH, "config.json")

-    model = Xtts.init_from_config(config)
-    model.load_checkpoint(
-        config,
-        checkpoint_dir=model_path,
-        eval=True,
-        use_deepspeed=USE_DEEPSPEED
-    )
+    if os.path.exists(local_config):
+        # Load local/fine-tuned model
+        logger.info(f"Loading local model from {MODEL_PATH}")
+        config = XttsConfig()
+        config.load_json(local_config)
+        model = Xtts.init_from_config(config)
+        model.load_checkpoint(
+            config,
+            checkpoint_dir=MODEL_PATH,
+            eval=True,
+            use_deepspeed=USE_DEEPSPEED
+        )
+    else:
+        # Load default XTTS-v2 from Hugging Face Hub via TTS API
+        logger.info("Loading default coqui/XTTS-v2 model from Hugging Face Hub...")
+        tts_api = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=torch.cuda.is_available())
+        model = tts_api.synthesizer.tts_model
+        config = tts_api.synthesizer.tts_config

     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = model.to(device)
@@ -61,6 +69,7 @@ def download_and_load_model():
     if USE_FP16 and device == "cuda":
         logger.info("Enabling FP16 inference...")
         model.half()
+        # Keep some layers in FP32 for stability
         if hasattr(model, 'gpt'):
             model.gpt.float()

@@ -83,22 +92,26 @@ def download_and_load_model():
     return model, config, device

 # Global model instance
-model, config, device = download_and_load_model()
+model, config, device = load_model()

 # ============== Speaker Caching ==============
 class SpeakerCache:
+    """LRU cache for speaker embeddings with hash-based keys"""
+
     def __init__(self, max_size: int = 10):
         self.max_size = max_size
         self.cache = {}
         self.order = []

     def _hash_audio(self, audio_path: str) -> str:
+        """Create hash from audio file for cache key"""
         with open(audio_path, 'rb') as f:
             return hashlib.md5(f.read()).hexdigest()[:16]

     def get(self, audio_path: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
         key = self._hash_audio(audio_path)
         if key in self.cache:
+            # Move to end (most recently used)
             self.order.remove(key)
             self.order.append(key)
             return self.cache[key]
@@ -106,12 +119,15 @@ class SpeakerCache:

     def set(self, audio_path: str, latents: Tuple[torch.Tensor, torch.Tensor]):
         key = self._hash_audio(audio_path)
+
+        # Evict oldest if at capacity
         if len(self.cache) >= self.max_size and key not in self.cache:
             oldest = self.order.pop(0)
             del self.cache[oldest]
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+
         self.cache[key] = latents
         if key not in self.order:
             self.order.append(key)
@@ -128,6 +144,9 @@ speaker_cache = SpeakerCache(max_size=MAX_CACHE_SIZE)
 # ============== Core Functions ==============
 @torch.inference_mode()
 def get_speaker_latents(speaker_wav: str) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Get speaker conditioning with caching"""
+
+    # Check cache first
     cached = speaker_cache.get(speaker_wav)
     if cached is not None:
         logger.info("Using cached speaker latents")
@@ -136,12 +155,13 @@ def get_speaker_latents(speaker_wav: str) -> Tuple[torch.Tensor, torch.Tensor]:
     logger.info("Computing speaker latents...")
     gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
         audio_path=speaker_wav,
-        gpt_cond_len=getattr(config, 'gpt_cond_len', 6),
-        gpt_cond_chunk_len=getattr(config, 'gpt_cond_chunk_len', 3),
-        max_ref_length=getattr(config, 'max_ref_len', 30),
-        sound_norm_refs=getattr(config, 'sound_norm_refs', False),
+        gpt_cond_len=config.gpt_cond_len if hasattr(config, 'gpt_cond_len') else 6,
+        gpt_cond_chunk_len=config.gpt_cond_chunk_len if hasattr(config, 'gpt_cond_chunk_len') else 3,
+        max_ref_length=config.max_ref_len if hasattr(config, 'max_ref_len') else 30,
+        sound_norm_refs=config.sound_norm_refs if hasattr(config, 'sound_norm_refs') else False,
     )

+    # Move to correct device and dtype
     if USE_FP16 and device == "cuda":
         gpt_cond_latent = gpt_cond_latent.half()
         speaker_embedding = speaker_embedding.half()
@@ -162,7 +182,11 @@ def synthesize(
     length_penalty: float = 1.0,
     speed: float = 1.0
 ) -> Optional[Tuple[int, np.ndarray]]:
-    if not text.strip() or not speaker_wav:
+    """Standard synthesis with optimizations"""
+
+    if not text.strip():
+        return None
+    if not speaker_wav:
         return None

     try:
@@ -183,7 +207,8 @@ def synthesize(
         )

         wav = np.array(out["wav"])
-        sample_rate = getattr(config.audio, 'output_sample_rate', 24000)
+        sample_rate = config.audio.output_sample_rate if hasattr(config.audio, 'output_sample_rate') else 24000
+
         return (sample_rate, wav)

     except Exception as e:
@@ -202,6 +227,8 @@ def synthesize_streaming(
     repetition_penalty: float = 5.0,
     speed: float = 1.0
 ):
+    """Streaming synthesis for lower latency"""
+
     if not text.strip() or not speaker_wav:
         return

@@ -222,7 +249,7 @@ def synthesize_streaming(
         enable_text_splitting=True
     )

-    sample_rate = getattr(config.audio, 'output_sample_rate', 24000)
+    sample_rate = config.audio.output_sample_rate if hasattr(config.audio, 'output_sample_rate') else 24000

     for chunk in chunks:
         if chunk is not None:
@@ -234,42 +261,85 @@ def synthesize_streaming(


 def clear_cache():
+    """Clear speaker cache and CUDA memory"""
     speaker_cache.clear()
     return "Cache cleared!"


 # ============== Gradio Interface ==============
 LANGUAGES = [
-    ("English", "en"), ("Spanish", "es"), ("French", "fr"), ("German", "de"),
-    ("Italian", "it"), ("Portuguese", "pt"), ("Polish", "pl"), ("Turkish", "tr"),
-    ("Russian", "ru"), ("Dutch", "nl"), ("Czech", "cs"), ("Arabic", "ar"),
-    ("Chinese", "zh-cn"), ("Japanese", "ja"), ("Hungarian", "hu"), ("Korean", "ko"),
+    ("English", "en"),
+    ("Spanish", "es"),
+    ("French", "fr"),
+    ("German", "de"),
+    ("Italian", "it"),
+    ("Portuguese", "pt"),
+    ("Polish", "pl"),
+    ("Turkish", "tr"),
+    ("Russian", "ru"),
+    ("Dutch", "nl"),
+    ("Czech", "cs"),
+    ("Arabic", "ar"),
+    ("Chinese", "zh-cn"),
+    ("Japanese", "ja"),
+    ("Hungarian", "hu"),
+    ("Korean", "ko"),
     ("Hindi", "hi"),
 ]

-with gr.Blocks(title="🐸 XTTSv2 TTS", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🐸 XTTSv2 Text-to-Speech\nHigh-quality multilingual voice cloning.")
+css = """
+.generate-btn {
+    background: linear-gradient(90deg, #4CAF50 0%, #45a049 100%) !important;
+    border: none !important;
+}
+.generate-btn:hover {
+    background: linear-gradient(90deg, #45a049 0%, #3d8b40 100%) !important;
+}
+footer {visibility: hidden}
+"""
+
+with gr.Blocks(title="🐸 XTTSv2 TTS", css=css, theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    # 🐸 XTTSv2 Text-to-Speech
+
+    High-quality multilingual voice cloning with optimized inference.
+    Upload a reference audio (6+ seconds recommended) and enter your text.
+    """)

     with gr.Tabs():
+        # Standard Tab
         with gr.TabItem("🎙️ Standard"):
             with gr.Row():
-                with gr.Column():
-                    text_input = gr.Textbox(label="Text", placeholder="Enter text...", lines=4)
-                    speaker_wav = gr.Audio(label="Reference Audio", type="filepath")
-                    language = gr.Dropdown(choices=LANGUAGES, value="en", label="Language")
+                with gr.Column(scale=1):
+                    text_input = gr.Textbox(
+                        label="Text to synthesize",
+                        placeholder="Enter text here...",
+                        lines=4,
+                        max_lines=10
+                    )
+                    speaker_wav = gr.Audio(
+                        label="Reference Audio",
+                        type="filepath",
+                        sources=["upload", "microphone"]
+                    )
+                    language = gr.Dropdown(
+                        choices=LANGUAGES,
+                        value="en",
+                        label="Language"
+                    )

-                    with gr.Accordion("Advanced", open=False):
-                        temperature = gr.Slider(0.1, 1.0, value=0.65, label="Temperature")
-                        top_p = gr.Slider(0.1, 1.0, value=0.85, label="Top P")
-                        top_k = gr.Slider(1, 100, value=50, label="Top K")
-                        repetition_penalty = gr.Slider(1.0, 15.0, value=5.0, label="Repetition Penalty")
-                        length_penalty = gr.Slider(0.5, 2.0, value=1.0, label="Length Penalty")
-                        speed = gr.Slider(0.5, 2.0, value=1.0, label="Speed")
+                    with gr.Accordion("Advanced Settings", open=False):
+                        temperature = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
+                        top_p = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
+                        top_k = gr.Slider(1, 100, value=50, step=1, label="Top K")
+                        repetition_penalty = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
+                        length_penalty = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Length Penalty")
+                        speed = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")

-                    generate_btn = gr.Button("🔊 Generate", variant="primary")
+                    generate_btn = gr.Button("🔊 Generate Speech", variant="primary", elem_classes=["generate-btn"])

-                with gr.Column():
-                    audio_output = gr.Audio(label="Output")
+                with gr.Column(scale=1):
+                    audio_output = gr.Audio(label="Generated Speech", type="numpy")

             generate_btn.click(
                 fn=synthesize,
@@ -277,28 +347,71 @@ with gr.Blocks(title="🐸 XTTSv2 TTS", theme=gr.themes.Soft()) as demo:
                 outputs=audio_output
             )

-        with gr.TabItem("⚡ Streaming"):
+        # Streaming Tab
+        with gr.TabItem("⚡ Streaming (Low Latency)"):
             with gr.Row():
-                with gr.Column():
-                    text_stream = gr.Textbox(label="Text", lines=4)
-                    speaker_stream = gr.Audio(label="Reference Audio", type="filepath")
-                    lang_stream = gr.Dropdown(choices=LANGUAGES, value="en", label="Language")
-                    stream_btn = gr.Button("⚡ Stream", variant="primary")
+                with gr.Column(scale=1):
+                    text_input_stream = gr.Textbox(
+                        label="Text to synthesize",
+                        placeholder="Enter text here...",
+                        lines=4
+                    )
+                    speaker_wav_stream = gr.Audio(
+                        label="Reference Audio",
+                        type="filepath",
+                        sources=["upload", "microphone"]
+                    )
+                    language_stream = gr.Dropdown(
+                        choices=LANGUAGES,
+                        value="en",
+                        label="Language"
+                    )
+
+                    with gr.Accordion("Advanced Settings", open=False):
+                        temp_stream = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
+                        top_p_stream = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
+                        top_k_stream = gr.Slider(1, 100, value=50, step=1, label="Top K")
+                        rep_pen_stream = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
+                        speed_stream = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")
+
+                    stream_btn = gr.Button("⚡ Stream Speech", variant="primary")

-                with gr.Column():
-                    audio_stream = gr.Audio(label="Output", streaming=True, autoplay=True)
+                with gr.Column(scale=1):
+                    audio_output_stream = gr.Audio(label="Streaming Output", streaming=True, autoplay=True)

             stream_btn.click(
                 fn=synthesize_streaming,
-                inputs=[text_stream, speaker_stream, lang_stream],
-                outputs=audio_stream
+                inputs=[text_input_stream, speaker_wav_stream, language_stream, temp_stream, top_p_stream, top_k_stream, rep_pen_stream, speed_stream],
+                outputs=audio_output_stream
             )

+        # Settings Tab
         with gr.TabItem("⚙️ Settings"):
-            gr.Markdown(f"**Device**: {device} | **DeepSpeed**: {USE_DEEPSPEED} | **FP16**: {USE_FP16}")
-            clear_btn = gr.Button("🗑️ Clear Cache")
-            status = gr.Textbox(label="Status", interactive=False)
-            clear_btn.click(fn=clear_cache, outputs=status)
+            gr.Markdown(f"""
+            ### Current Configuration
+            - **Device**: {device}
+            - **DeepSpeed**: {'Enabled' if USE_DEEPSPEED else 'Disabled'}
+            - **FP16**: {'Enabled' if USE_FP16 else 'Disabled'}
+            - **torch.compile**: {'Enabled' if USE_TORCH_COMPILE else 'Disabled'}
+            - **Max Cached Speakers**: {MAX_CACHE_SIZE}
+            """)
+
+            clear_cache_btn = gr.Button("🗑️ Clear Speaker Cache")
+            cache_status = gr.Textbox(label="Status", interactive=False)
+
+            clear_cache_btn.click(fn=clear_cache, outputs=cache_status)
+
+            gr.Markdown("""
+            ---
+            **Tips for best results:**
+            - Use clean reference audio with minimal background noise
+            - 6-30 seconds of reference audio works best
+            - Match the language of your text to your reference audio for best quality
+            """)

 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)
+    demo.queue(max_size=10).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True
+    )