"""
Optimized XTTSv2 Hugging Face Space

- DeepSpeed acceleration
- FP16 inference
- torch.compile() optimization
- Speaker latent caching
- Streaming inference
- Memory optimization
"""

import functools
import gc
import hashlib
import logging
import os
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# PyTorch 2.6 made weights_only=True the default for torch.load(), which
# breaks unpickling the config objects stored inside XTTS checkpoints.
# Restore the old behavior; acceptable here because we only load trusted
# checkpoints.
torch.load = functools.partial(torch.load, weights_only=False)

# Accept the Coqui model license non-interactively.
os.environ["COQUI_TOS_AGREED"] = "1"


MODEL_PATH = os.environ.get("MODEL_PATH", "./model")
USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "false").lower() == "true"
USE_FP16 = os.environ.get("USE_FP16", "true").lower() == "true"
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "false").lower() == "true"
MAX_CACHE_SIZE = int(os.environ.get("MAX_CACHE_SIZE", "10"))
STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20"))
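
# All knobs above come from environment variables, so a Space can be retuned
# without code changes. Example invocation (hypothetical values; on Spaces,
# set these in the repository settings instead):
#
#   USE_FP16=true USE_TORCH_COMPILE=true MAX_CACHE_SIZE=20 python app.py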


def load_model():
    """Load XTTSv2 with all optimizations enabled via environment flags."""
    from TTS.api import TTS
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    logger.info("Loading XTTSv2 model...")

    local_config = os.path.join(MODEL_PATH, "config.json")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if os.path.exists(local_config):
        # Load a local / fine-tuned checkpoint from MODEL_PATH.
        config = XttsConfig()
        config.load_json(local_config)
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_dir=MODEL_PATH,
            eval=True,
            use_deepspeed=USE_DEEPSPEED,
        )
    else:
        # Fall back to the stock checkpoint from the Hugging Face Hub.
        logger.info("Loading default coqui/XTTS-v2 from Hub...")
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
        model = tts.synthesizer.tts_model
        config = tts.synthesizer.tts_config

    model.to(device)

    if USE_FP16 and device == "cuda":
        logger.info("Enabling FP16 inference...")
        model.half()

    if USE_TORCH_COMPILE and hasattr(torch, "compile"):
        try:
            # Compile only the GPT backbone, where most inference time is spent.
            model.gpt = torch.compile(model.gpt, mode="reduce-overhead")
            logger.info("GPT compiled successfully.")
        except Exception as e:
            logger.warning(f"torch.compile failed, skipping: {e}")

    model.eval()
    return model, config, device


model, config, device = load_model()
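
# Note: model.half() casts every weight to fp16. This speeds up the GPT
# autoregression on most GPUs but may slightly degrade audio quality on some
# inputs; set USE_FP16=false if you hear artifacts.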


class SpeakerCache:
    """LRU cache for speaker latents, keyed by a hash of the reference audio."""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self.cache = {}
        self.order = []  # least-recently-used key first

    def _hash_audio(self, audio_path: str) -> str:
        """Create a cache key from the audio file's contents."""
        with open(audio_path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()[:16]

    def get(self, audio_path: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        key = self._hash_audio(audio_path)
        if key in self.cache:
            # Refresh recency on a hit.
            self.order.remove(key)
            self.order.append(key)
            return self.cache[key]
        return None

    def set(self, audio_path: str, latents: Tuple[torch.Tensor, torch.Tensor]):
        key = self._hash_audio(audio_path)

        # Evict the least-recently-used entry when the cache is full.
        if len(self.cache) >= self.max_size and key not in self.cache:
            oldest = self.order.pop(0)
            del self.cache[oldest]
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.cache[key] = latents
        if key not in self.order:
            self.order.append(key)

    def clear(self):
        self.cache.clear()
        self.order.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


speaker_cache = SpeakerCache(max_size=MAX_CACHE_SIZE)
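
# Usage sketch (hypothetical path): the cache is keyed by file contents, so a
# re-upload of the same clip hits the cache even under a new temp filename.
#
#   speaker_cache.set("voices/alice.wav", latents)  # store once
#   speaker_cache.get("voices/alice.wav")           # later hits skip recompute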


@torch.inference_mode()
def get_speaker_latents(speaker_wav: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute speaker conditioning latents, with caching."""
    cached = speaker_cache.get(speaker_wav)
    if cached is not None:
        logger.info("Using cached speaker latents")
        return cached

    logger.info("Computing speaker latents...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=speaker_wav,
        gpt_cond_len=getattr(config, "gpt_cond_len", 6),
        gpt_cond_chunk_len=getattr(config, "gpt_cond_chunk_len", 3),
        max_ref_length=getattr(config, "max_ref_len", 30),
        sound_norm_refs=getattr(config, "sound_norm_refs", False),
    )

    # Cache the latents in the model's dtype so they can be reused without
    # per-request casts.
    if USE_FP16 and device == "cuda":
        gpt_cond_latent = gpt_cond_latent.half()
        speaker_embedding = speaker_embedding.half()

    speaker_cache.set(speaker_wav, (gpt_cond_latent, speaker_embedding))
    return gpt_cond_latent, speaker_embedding
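
# Optional warm-up sketch (assumes a bundled sample clip at a hypothetical
# path): computing latents for known voices at startup moves the conditioning
# pass out of the first user request.
#
#   if os.path.exists("samples/default_voice.wav"):
#       get_speaker_latents("samples/default_voice.wav")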


@torch.inference_mode()
def synthesize(
    text: str,
    speaker_wav: str,
    language: str,
    temperature: float = 0.65,
    top_p: float = 0.85,
    top_k: int = 50,
    repetition_penalty: float = 5.0,
    length_penalty: float = 1.0,
    speed: float = 1.0,
) -> Optional[Tuple[int, np.ndarray]]:
    """Standard (non-streaming) synthesis."""
    if not text.strip():
        return None
    if not speaker_wav:
        return None

    try:
        gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav)

        out = model.inference(
            text=text,
            language=language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            speed=speed,
            enable_text_splitting=True,
        )

        # Gradio expects float32 (or int16) audio; upcast in case of FP16.
        wav = np.asarray(out["wav"], dtype=np.float32)
        sample_rate = getattr(config.audio, "output_sample_rate", 24000)

        return (sample_rate, wav)

    except Exception as e:
        logger.error(f"Synthesis error: {e}")
        raise gr.Error(f"Synthesis failed: {e}")


@torch.inference_mode()
def synthesize_streaming(
    text: str,
    speaker_wav: str,
    language: str,
    temperature: float = 0.65,
    top_p: float = 0.85,
    top_k: int = 50,
    repetition_penalty: float = 5.0,
    speed: float = 1.0,
):
    """Streaming synthesis: yields audio chunks for lower perceived latency."""
    if not text.strip() or not speaker_wav:
        return

    try:
        gpt_cond_latent, speaker_embedding = get_speaker_latents(speaker_wav)

        chunks = model.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            speed=speed,
            stream_chunk_size=STREAMING_CHUNK_SIZE,
            enable_text_splitting=True,
        )

        sample_rate = getattr(config.audio, "output_sample_rate", 24000)

        for chunk in chunks:
            if chunk is not None:
                # Upcast fp16 chunks to float32 for Gradio's streaming player.
                yield (sample_rate, chunk.float().cpu().numpy().squeeze())

    except Exception as e:
        logger.error(f"Streaming error: {e}")
        raise gr.Error(f"Streaming failed: {e}")


def clear_cache():
    """Clear the speaker cache and release as much CUDA memory as possible."""
    speaker_cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return "Cache and VRAM cleared!"


# (display name, ISO code) pairs accepted by XTTSv2.
LANGUAGES = [
    ("English", "en"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("German", "de"),
    ("Italian", "it"),
    ("Portuguese", "pt"),
    ("Polish", "pl"),
    ("Turkish", "tr"),
    ("Russian", "ru"),
    ("Dutch", "nl"),
    ("Czech", "cs"),
    ("Arabic", "ar"),
    ("Chinese", "zh-cn"),
    ("Japanese", "ja"),
    ("Hungarian", "hu"),
    ("Korean", "ko"),
    ("Hindi", "hi"),
]


css = """
.generate-btn {
    background: linear-gradient(90deg, #4CAF50 0%, #45a049 100%) !important;
    border: none !important;
}
.generate-btn:hover {
    background: linear-gradient(90deg, #45a049 0%, #3d8b40 100%) !important;
}
footer {visibility: hidden}
"""


with gr.Blocks(title="🎸 XTTSv2 TTS", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎸 XTTSv2 Text-to-Speech

    High-quality multilingual voice cloning with optimized inference.
    Upload a reference audio clip (6+ seconds recommended) and enter your text.
    """)

    with gr.Tabs():
        with gr.TabItem("🎙️ Standard"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Text to synthesize",
                        placeholder="Enter text here...",
                        lines=4,
                        max_lines=10,
                    )
                    speaker_wav = gr.Audio(
                        label="Reference Audio",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    language = gr.Dropdown(
                        choices=LANGUAGES,
                        value="en",
                        label="Language",
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        temperature = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
                        top_p = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
                        top_k = gr.Slider(1, 100, value=50, step=1, label="Top K")
                        repetition_penalty = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
                        length_penalty = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Length Penalty")
                        speed = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")

                    generate_btn = gr.Button("🚀 Generate Speech", variant="primary", elem_classes=["generate-btn"])

                with gr.Column(scale=1):
                    audio_output = gr.Audio(label="Generated Speech", type="numpy")

            generate_btn.click(
                fn=synthesize,
                inputs=[text_input, speaker_wav, language, temperature, top_p, top_k, repetition_penalty, length_penalty, speed],
                outputs=audio_output,
            )

        with gr.TabItem("⚡ Streaming (Low Latency)"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input_stream = gr.Textbox(
                        label="Text to synthesize",
                        placeholder="Enter text here...",
                        lines=4,
                    )
                    speaker_wav_stream = gr.Audio(
                        label="Reference Audio",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    language_stream = gr.Dropdown(
                        choices=LANGUAGES,
                        value="en",
                        label="Language",
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        temp_stream = gr.Slider(0.1, 1.0, value=0.65, step=0.05, label="Temperature")
                        top_p_stream = gr.Slider(0.1, 1.0, value=0.85, step=0.05, label="Top P")
                        top_k_stream = gr.Slider(1, 100, value=50, step=1, label="Top K")
                        rep_pen_stream = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="Repetition Penalty")
                        speed_stream = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speed")

                    stream_btn = gr.Button("⚡ Stream Speech", variant="primary")

                with gr.Column(scale=1):
                    audio_output_stream = gr.Audio(label="Streaming Output", streaming=True, autoplay=True)

            stream_btn.click(
                fn=synthesize_streaming,
                inputs=[text_input_stream, speaker_wav_stream, language_stream, temp_stream, top_p_stream, top_k_stream, rep_pen_stream, speed_stream],
                outputs=audio_output_stream,
            )

        with gr.TabItem("⚙️ Settings"):
            gr.Markdown(f"""
            ### Current Configuration
            - **Device**: {device}
            - **DeepSpeed**: {'Enabled' if USE_DEEPSPEED else 'Disabled'}
            - **FP16**: {'Enabled' if USE_FP16 else 'Disabled'}
            - **torch.compile**: {'Enabled' if USE_TORCH_COMPILE else 'Disabled'}
            - **Max Cached Speakers**: {MAX_CACHE_SIZE}
            """)

            clear_cache_btn = gr.Button("🗑️ Clear Speaker Cache")
            cache_status = gr.Textbox(label="Status", interactive=False)

            clear_cache_btn.click(fn=clear_cache, outputs=cache_status)

    gr.Markdown("""
    ---
    **Tips for best results:**
    - Use clean reference audio with minimal background noise
    - 6-30 seconds of reference audio works best
    - Match the language of your text to your reference audio for best quality
    """)


if __name__ == "__main__":
    # Queue requests so concurrent users share the single model worker.
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
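
# The queued app can also be driven programmatically (a sketch, assuming the
# default local address and the gradio_client package):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")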