# xttsv2/app.py
"""
Optimized XTTSv2 Hugging Face Space
- DeepSpeed acceleration
- FP16 inference
- torch.compile() optimization
- Speaker latent caching
- Streaming inference
- Memory optimization
"""
import functools
import gc
import hashlib
import logging
import os
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
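# PyTorch 2.6 flipped torch.load's default to weights_only=True, which
# rejects the pickled config objects bundled in XTTS checkpoints; restore
# the old default so checkpoint loading keeps working.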
torch.load = functools.partial(torch.load, weights_only=False)
# Auto-accept Coqui TOS for non-interactive environments
os.environ["COQUI_TOS_AGREED"] = "1"
# ============== Configuration ==============
MODEL_PATH = os.environ.get("MODEL_PATH", "./model")
USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "false").lower() == "true" # Disabled by default for stability
USE_FP16 = os.environ.get("USE_FP16", "true").lower() == "true"
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "false").lower() == "true" # Disabled by default for stability
MAX_CACHE_SIZE = int(os.environ.get("MAX_CACHE_SIZE", "10")) # Max cached speakers
STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20"))
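# Each of these can be overridden per deployment, e.g. a (hypothetical)
# local launch:
#   USE_FP16=false USE_TORCH_COMPILE=true python app.py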
# ============== Model Loading ==============
def load_model():
"""Load XTTSv2 with all optimizations"""
# Import inside function to prevent early CUDA initialization
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.api import TTS
logger.info("Loading XTTSv2 model...")
# Check if local model exists
local_config = os.path.join(MODEL_PATH, "config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
if os.path.exists(local_config):
config = XttsConfig()
config.load_json(local_config)
model = Xtts.init_from_config(config)
model.load_checkpoint(
config,
checkpoint_dir=MODEL_PATH,
eval=True,
use_deepspeed=USE_DEEPSPEED
)
else:
        # Use the high-level API for Hub loads; it handles weight download
        # and placement. We go through the synthesizer to reach the raw
        # model object so the optimizations below can be applied to it.
        logger.info("Loading default coqui/XTTS-v2 from Hub...")
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
model = tts.synthesizer.tts_model
config = tts.synthesizer.tts_config
model.to(device)
if USE_FP16 and device == "cuda":
logger.info("Enabling FP16 inference...")
model.half()
    # torch.compile setup (the default inductor backend needs Triton on CUDA)
if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
try:
# We only compile the GPT part as it's the bottleneck
model.gpt = torch.compile(model.gpt, mode="reduce-overhead")
logger.info("GPT compiled successfully.")
except Exception as e:
logger.warning(f"torch.compile failed, skipping: {e}")
model.eval()
return model, config, device
# Global model instance, loaded once at import so all requests share it
model, config, device = load_model()
# ============== Speaker Caching ==============
class SpeakerCache:
"""LRU cache for speaker embeddings with hash-based keys"""
def __init__(self, max_size: int = 10):
self.max_size = max_size
self.cache = {}
self.order = []
    def _hash_audio(self, audio_path: str) -> str:
        """Create a cache key from the audio file contents (md5 here is a
        cache key only, not a security measure)"""
        with open(audio_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()[:16]
def get(self, audio_path: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
key = self._hash_audio(audio_path)
if key in self.cache:
# Move to end (most recently used)
self.order.remove(key)
self.order.append(key)
return self.cache[key]
return None
def set(self, audio_path: str, latents: Tuple[torch.Tensor, torch.Tensor]):
key = self._hash_audio(audio_path)
# Evict oldest if at capacity
if len(self.cache) >= self.max_size and key not in self.cache:
oldest = self.order.pop(0)
del self.cache[oldest]
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.cache[key] = latents
if key not in self.order:
self.order.append(key)
def clear(self):
self.cache.clear()
self.order.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
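# Conditioning latents are small relative to the model weights, so caching
# MAX_CACHE_SIZE speakers bounds memory growth while skipping the slowest
# step for repeat requests. A rough usage sketch (hypothetical path):
#   latents = speaker_cache.get("ref.wav")   # None on a cold cache
#   speaker_cache.set("ref.wav", latents)    # cached for subsequent calls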
speaker_cache = SpeakerCache(max_size=MAX_CACHE_SIZE)
# ============== Core Functions ==============
@torch.inference_mode()
def get_speaker_latents(speaker_wav: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get speaker conditioning with caching"""
# Check cache first
cached = speaker_cache.get(speaker_wav)
if cached is not None:
logger.info("Using cached speaker latents")
return cached
logger.info("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
audio_path=speaker_wav,
        gpt_cond_len=getattr(config, 'gpt_cond_len', 6),
        gpt_cond_chunk_len=getattr(config, 'gpt_cond_chunk_len', 3),
        max_ref_length=getattr(config, 'max_ref_len', 30),
        sound_norm_refs=getattr(config, 'sound_norm_refs', False),
)
    # Match the model's dtype so FP16 conditioning matches FP16 weights
if USE_FP16 and device == "cuda":
gpt_cond_latent = gpt_cond_latent.half()
speaker_embedding = speaker_embedding.half()
speaker_cache.set(speaker_wav, (gpt_cond_latent, speaker_embedding))
return gpt_cond_latent, speaker_embedding
@torch.inference_mode()
def synthesize(
text: str,
speaker_wav: str,
language: str,
temperature: float = 0.65,
top_p: float = 0.85,
top_k: int = 50,
repetition_penalty: float = 5.0,
length_penalty: float = 1.0,
speed: float = 1.0
) -> Optional[Tuple[int, np.ndarray]]:
"""Standard synthesis with optimizations"""
if not text.strip():
return None
if not speaker_wav:
return None
try:
gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav)
out = model.inference(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
speed=speed,
enable_text_splitting=True
)
wav = np.array(out["wav"])
sample_rate = config.audio.output_sample_rate if hasattr(config.audio, 'output_sample_rate') else 24000
return (sample_rate, wav)
except Exception as e:
logger.error(f"Synthesis error: {e}")
raise gr.Error(f"Synthesis failed: {str(e)}")
@torch.inference_mode()
def synthesize_streaming(
text: str,
speaker_wav: str,
language: str,
temperature: float = 0.65,
top_p: float = 0.85,
top_k: int = 50,
repetition_penalty: float = 5.0,
speed: float = 1.0
):
"""Streaming synthesis for lower latency"""
if not text.strip() or not speaker_wav:
return
try:
gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav)
chunks = model.inference_stream(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
speed=speed,
stream_chunk_size=STREAMING_CHUNK_SIZE,
enable_text_splitting=True
)
        sample_rate = getattr(config.audio, 'output_sample_rate', 24000)
        for chunk in chunks:
            if chunk is not None:
                # Cast to float32 so half-precision chunks are safe to hand to Gradio
                yield (sample_rate, chunk.squeeze().float().cpu().numpy())
except Exception as e:
logger.error(f"Streaming error: {e}")
raise gr.Error(f"Streaming failed: {str(e)}")
def clear_cache():
"""Clear speaker cache and exhaustively free CUDA memory"""
speaker_cache.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
return "Cache and VRAM cleared!"
# ============== Gradio Interface ==============
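# The 17 languages supported by XTTS v2; the second tuple element is the
# language code the model expects (note "zh-cn" rather than plain "zh").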
LANGUAGES = [
("English", "en"),
("Spanish", "es"),
("French", "fr"),
("German", "de"),
("Italian", "it"),
("Portuguese", "pt"),
("Polish", "pl"),
("Turkish", "tr"),
("Russian", "ru"),
("Dutch", "nl"),
("Czech", "cs"),
("Arabic", "ar"),
("Chinese", "zh-cn"),
("Japanese", "ja"),
("Hungarian", "hu"),
("Korean", "ko"),
("Hindi", "hi"),
]
css = """
.generate-btn {
background: linear-gradient(90deg, #4CAF50 0%, #45a049 100%) !important;
border: none !important;
}
.generate-btn:hover {
background: linear-gradient(90deg, #45a049 0%, #3d8b40 100%) !important;
}
footer {visibility: hidden}
"""
with gr.Blocks(title="๐Ÿธ XTTSv2 TTS", css=css, theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ๐Ÿธ XTTSv2 Text-to-Speech
High-quality multilingual voice cloning with optimized inference.
Upload a reference audio clip (6+ seconds recommended) and enter your text.
""")
with gr.Tabs():
# Standard Tab
with gr.TabItem("๐ŸŽ™๏ธ Standard"):
with gr.Row():
with gr.Column(scale=1):
text_input = gr.Textbox(
label="Text to synthesize",
placeholder="Enter text here...",
lines=4,
max_lines=10
)
speaker_wav = gr.Audio(
label="Reference Audio",
type="filepath",
sources=["upload", "microphone"]
)
language = gr.Dropdown(
choices=LANGUAGES,
value="en",
label="Language"
)
with gr.Accordion("Advanced Settings", open=False):
temperature = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
top_k = gr.Slider(1, 100, value=50, step=1, label="Top K")
repetition_penalty = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
length_penalty = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Length Penalty")
speed = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")
generate_btn = gr.Button("๐Ÿ”Š Generate Speech", variant="primary", elem_classes=["generate-btn"])
with gr.Column(scale=1):
audio_output = gr.Audio(label="Generated Speech", type="numpy")
generate_btn.click(
fn=synthesize,
inputs=[text_input, speaker_wav, language, temperature, top_p, top_k, repetition_penalty, length_penalty, speed],
outputs=audio_output
)
# Streaming Tab
with gr.TabItem("โšก Streaming (Low Latency)"):
with gr.Row():
with gr.Column(scale=1):
text_input_stream = gr.Textbox(
label="Text to synthesize",
placeholder="Enter text here...",
lines=4
)
speaker_wav_stream = gr.Audio(
label="Reference Audio",
type="filepath",
sources=["upload", "microphone"]
)
language_stream = gr.Dropdown(
choices=LANGUAGES,
value="en",
label="Language"
)
with gr.Accordion("Advanced Settings", open=False):
temp_stream = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
top_p_stream = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
top_k_stream = gr.Slider(1, 100, value=50, step=1, label="Top K")
rep_pen_stream = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
speed_stream = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")
stream_btn = gr.Button("โšก Stream Speech", variant="primary")
with gr.Column(scale=1):
audio_output_stream = gr.Audio(label="Streaming Output", streaming=True, autoplay=True)
stream_btn.click(
fn=synthesize_streaming,
inputs=[text_input_stream, speaker_wav_stream, language_stream, temp_stream, top_p_stream, top_k_stream, rep_pen_stream, speed_stream],
outputs=audio_output_stream
)
# Settings Tab
with gr.TabItem("โš™๏ธ Settings"):
gr.Markdown(f"""
### Current Configuration
- **Device**: {device}
- **DeepSpeed**: {'Enabled' if USE_DEEPSPEED else 'Disabled'}
- **FP16**: {'Enabled' if USE_FP16 else 'Disabled'}
- **torch.compile**: {'Enabled' if USE_TORCH_COMPILE else 'Disabled'}
- **Max Cached Speakers**: {MAX_CACHE_SIZE}
""")
clear_cache_btn = gr.Button("๐Ÿ—‘๏ธ Clear Speaker Cache")
cache_status = gr.Textbox(label="Status", interactive=False)
clear_cache_btn.click(fn=clear_cache, outputs=cache_status)
gr.Markdown("""
---
**Tips for best results:**
- Use clean reference audio with minimal background noise
- 6-30 seconds of reference audio works best
- Match the language of your text to your reference audio for best quality
""")
if __name__ == "__main__":
demo.queue(max_size=10).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
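# Client-side usage sketch with gradio_client (the Space URL is a
# placeholder; confirm the endpoint name with client.view_api() first):
#
#   from gradio_client import Client, handle_file
#   client = Client("<your-space-url>")
#   result = client.predict(
#       "Hello there!",           # text
#       handle_file("ref.wav"),   # reference audio clip
#       "en",                     # language
#       0.65, 0.85, 50, 5.0, 1.0, 1.0,  # advanced settings
#       api_name="<see view_api()>",
#   )  # result is typically a local path to the generated audio file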