""" Optimized XTTSv2 Hugging Face Space - DeepSpeed acceleration - FP16 inference - torch.compile() optimization - Speaker latent caching - Streaming inference - Memory optimization """ import gradio as gr import torch import os import gc import hashlib import tempfile import numpy as np from pathlib import Path from functools import lru_cache from typing import Optional, Tuple import logging import functools logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) torch.load = functools.partial(torch.load, weights_only=False) # Auto-accept Coqui TOS for non-interactive environments os.environ["COQUI_TOS_AGREED"] = "1" # ============== Configuration ============== MODEL_PATH = os.environ.get("MODEL_PATH", "./model") USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "false").lower() == "true" # Disabled by default for stability USE_FP16 = os.environ.get("USE_FP16", "true").lower() == "true" USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "false").lower() == "true" # Disabled by default for stability MAX_CACHE_SIZE = int(os.environ.get("MAX_CACHE_SIZE", "10")) # Max cached speakers STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20")) # ============== Model Loading ============== def load_model(): """Load XTTSv2 with all optimizations""" # Import inside function to prevent early CUDA initialization from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts from TTS.api import TTS logger.info("Loading XTTSv2 model...") # Check if local model exists local_config = os.path.join(MODEL_PATH, "config.json") device = "cuda" if torch.cuda.is_available() else "cpu" if os.path.exists(local_config): config = XttsConfig() config.load_json(local_config) model = Xtts.init_from_config(config) model.load_checkpoint( config, checkpoint_dir=MODEL_PATH, eval=True, use_deepspeed=USE_DEEPSPEED ) else: # Reverting to the high-level API for Hub loads as it handles weights better logger.info("Loading default coqui/XTTS-v2 from Hub...") # We use the synthesizer directly to access the model object for optimizations tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device) model = tts.synthesizer.tts_model config = tts.synthesizer.tts_config model.to(device) if USE_FP16 and device == "cuda": logger.info("Enabling FP16 inference...") model.half() # Logic for torch.compile (requires Triton for some features) if USE_TORCH_COMPILE and hasattr(torch, 'compile'): try: # We only compile the GPT part as it's the bottleneck model.gpt = torch.compile(model.gpt, mode="reduce-overhead") logger.info("GPT compiled successfully.") except Exception as e: logger.warning(f"torch.compile failed, skipping: {e}") model.eval() return model, config, device # Global model instance model, config, device = load_model() # ============== Speaker Caching ============== class SpeakerCache: """LRU cache for speaker embeddings with hash-based keys""" def __init__(self, max_size: int = 10): self.max_size = max_size self.cache = {} self.order = [] def _hash_audio(self, audio_path: str) -> str: """Create hash from audio file for cache key""" with open(audio_path, 'rb') as f: return hashlib.md5(f.read()).hexdigest()[:16] def get(self, audio_path: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: key = self._hash_audio(audio_path) if key in self.cache: # Move to end (most recently used) self.order.remove(key) self.order.append(key) return self.cache[key] return None def set(self, audio_path: str, latents: Tuple[torch.Tensor, torch.Tensor]): key = self._hash_audio(audio_path) # Evict oldest if at capacity if len(self.cache) >= self.max_size and key not in self.cache: oldest = self.order.pop(0) del self.cache[oldest] gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() self.cache[key] = latents if key not in self.order: self.order.append(key) def clear(self): self.cache.clear() self.order.clear() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() speaker_cache = SpeakerCache(max_size=MAX_CACHE_SIZE) # ============== Core Functions ============== @torch.inference_mode() def get_speaker_latents(speaker_wav: str) -> Tuple[torch.Tensor, torch.Tensor]: """Get speaker conditioning with caching""" # Check cache first cached = speaker_cache.get(speaker_wav) if cached is not None: logger.info("Using cached speaker latents") return cached logger.info("Computing speaker latents...") gpt_cond_latent, speaker_embedding = model.get_conditioning_latents( audio_path=speaker_wav, gpt_cond_len=config.gpt_cond_len if hasattr(config, 'gpt_cond_len') else 6, gpt_cond_chunk_len=config.gpt_cond_chunk_len if hasattr(config, 'gpt_cond_chunk_len') else 3, max_ref_length=config.max_ref_len if hasattr(config, 'max_ref_len') else 30, sound_norm_refs=config.sound_norm_refs if hasattr(config, 'sound_norm_refs') else False, ) # Move to correct device and dtype if USE_FP16 and device == "cuda": gpt_cond_latent = gpt_cond_latent.half() speaker_embedding = speaker_embedding.half() speaker_cache.set(speaker_wav, (gpt_cond_latent, speaker_embedding)) return gpt_cond_latent, speaker_embedding @torch.inference_mode() def synthesize( text: str, speaker_wav: str, language: str, temperature: float = 0.65, top_p: float = 0.85, top_k: int = 50, repetition_penalty: float = 5.0, length_penalty: float = 1.0, speed: float = 1.0 ) -> Optional[Tuple[int, np.ndarray]]: """Standard synthesis with optimizations""" if not text.strip(): return None if not speaker_wav: return None try: gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav) out = model.inference( text=text, language=language, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, length_penalty=length_penalty, speed=speed, enable_text_splitting=True ) wav = np.array(out["wav"]) sample_rate = config.audio.output_sample_rate if hasattr(config.audio, 'output_sample_rate') else 24000 return (sample_rate, wav) except Exception as e: logger.error(f"Synthesis error: {e}") raise gr.Error(f"Synthesis failed: {str(e)}") @torch.inference_mode() def synthesize_streaming( text: str, speaker_wav: str, language: str, temperature: float = 0.65, top_p: float = 0.85, top_k: int = 50, repetition_penalty: float = 5.0, speed: float = 1.0 ): """Streaming synthesis for lower latency""" if not text.strip() or not speaker_wav: return try: gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav) chunks = model.inference_stream( text=text, language=language, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, speed=speed, stream_chunk_size=STREAMING_CHUNK_SIZE, enable_text_splitting=True ) sample_rate = config.audio.output_sample_rate if hasattr(config.audio, 'output_sample_rate') else 24000 for chunk in chunks: if chunk is not None: yield (sample_rate, chunk.cpu().numpy().squeeze()) except Exception as e: logger.error(f"Streaming error: {e}") raise gr.Error(f"Streaming failed: {str(e)}") def clear_cache(): """Clear speaker cache and exhaustively free CUDA memory""" speaker_cache.clear() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() return "Cache and VRAM cleared!" # ============== Gradio Interface ============== LANGUAGES = [ ("English", "en"), ("Spanish", "es"), ("French", "fr"), ("German", "de"), ("Italian", "it"), ("Portuguese", "pt"), ("Polish", "pl"), ("Turkish", "tr"), ("Russian", "ru"), ("Dutch", "nl"), ("Czech", "cs"), ("Arabic", "ar"), ("Chinese", "zh-cn"), ("Japanese", "ja"), ("Hungarian", "hu"), ("Korean", "ko"), ("Hindi", "hi"), ] css = """ .generate-btn { background: linear-gradient(90deg, #4CAF50 0%, #45a049 100%) !important; border: none !important; } .generate-btn:hover { background: linear-gradient(90deg, #45a049 0%, #3d8b40 100%) !important; } footer {visibility: hidden} """ with gr.Blocks(title="🐸 XTTSv2 TTS", css=css, theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🐸 XTTSv2 Text-to-Speech High-quality multilingual voice cloning with optimized inference. Upload a reference audio (6+ seconds recommended) and enter your text. """) with gr.Tabs(): # Standard Tab with gr.TabItem("🎙️ Standard"): with gr.Row(): with gr.Column(scale=1): text_input = gr.Textbox( label="Text to synthesize", placeholder="Enter text here...", lines=4, max_lines=10 ) speaker_wav = gr.Audio( label="Reference Audio", type="filepath", sources=["upload", "microphone"] ) language = gr.Dropdown( choices=LANGUAGES, value="en", label="Language" ) with gr.Accordion("Advanced Settings", open=False): temperature = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P") top_k = gr.Slider(1, 100, value=50, step=1, label="Top K") repetition_penalty = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty") length_penalty = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Length Penalty") speed = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed") generate_btn = gr.Button("🔊 Generate Speech", variant="primary", elem_classes=["generate-btn"]) with gr.Column(scale=1): audio_output = gr.Audio(label="Generated Speech", type="numpy") generate_btn.click( fn=synthesize, inputs=[text_input, speaker_wav, language, temperature, top_p, top_k, repetition_penalty, length_penalty, speed], outputs=audio_output ) # Streaming Tab with gr.TabItem("⚡ Streaming (Low Latency)"): with gr.Row(): with gr.Column(scale=1): text_input_stream = gr.Textbox( label="Text to synthesize", placeholder="Enter text here...", lines=4 ) speaker_wav_stream = gr.Audio( label="Reference Audio", type="filepath", sources=["upload", "microphone"] ) language_stream = gr.Dropdown( choices=LANGUAGES, value="en", label="Language" ) with gr.Accordion("Advanced Settings", open=False): temp_stream = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature") top_p_stream = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P") top_k_stream = gr.Slider(1, 100, value=50, step=1, label="Top K") rep_pen_stream = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty") speed_stream = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed") stream_btn = gr.Button("⚡ Stream Speech", variant="primary") with gr.Column(scale=1): audio_output_stream = gr.Audio(label="Streaming Output", streaming=True, autoplay=True) stream_btn.click( fn=synthesize_streaming, inputs=[text_input_stream, speaker_wav_stream, language_stream, temp_stream, top_p_stream, top_k_stream, rep_pen_stream, speed_stream], outputs=audio_output_stream ) # Settings Tab with gr.TabItem("⚙️ Settings"): gr.Markdown(f""" ### Current Configuration - **Device**: {device} - **DeepSpeed**: {'Enabled' if USE_DEEPSPEED else 'Disabled'} - **FP16**: {'Enabled' if USE_FP16 else 'Disabled'} - **torch.compile**: {'Enabled' if USE_TORCH_COMPILE else 'Disabled'} - **Max Cached Speakers**: {MAX_CACHE_SIZE} """) clear_cache_btn = gr.Button("🗑️ Clear Speaker Cache") cache_status = gr.Textbox(label="Status", interactive=False) clear_cache_btn.click(fn=clear_cache, outputs=cache_status) gr.Markdown(""" --- **Tips for best results:** - Use clean reference audio with minimal background noise - 6-30 seconds of reference audio works best - Match the language of your text to your reference audio for best quality """) if __name__ == "__main__": demo.queue(max_size=10).launch( server_name="0.0.0.0", server_port=7860, show_error=True )