Upload 2 files

- app.py +23 -29
- requirements.txt +10 -7
app.py
CHANGED

@@ -37,18 +37,18 @@ STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20"))
 # ============== Model Loading ==============
 def load_model():
     """Load XTTSv2 with all optimizations"""
-
+    # Import inside function to prevent early CUDA initialization
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts
+    from TTS.api import TTS
 
     logger.info("Loading XTTSv2 model...")
 
-    # Check if local model exists
+    # Check if local model exists
     local_config = os.path.join(MODEL_PATH, "config.json")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     if os.path.exists(local_config):
-        # Load local/fine-tuned model
-        logger.info(f"Loading local model from {MODEL_PATH}")
         config = XttsConfig()
         config.load_json(local_config)
         model = Xtts.init_from_config(config)
@@ -59,39 +59,29 @@
             use_deepspeed=USE_DEEPSPEED
         )
     else:
-        # …
-        logger.info("Loading default coqui/XTTS-v2 …
-        …
-        …
-        …
-        …
-        …
-        model …
+        # Reverting to the high-level API for Hub loads as it handles weights better
+        logger.info("Loading default coqui/XTTS-v2 from Hub...")
+        # We use the synthesizer directly to access the model object for optimizations
+        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
+        model = tts.synthesizer.tts_model
+        config = tts.synthesizer.tts_config
+
+        model.to(device)
 
-    # FP16 optimization
     if USE_FP16 and device == "cuda":
         logger.info("Enabling FP16 inference...")
         model.half()
-        # Keep some layers in FP32 for stability
-        if hasattr(model, 'gpt'):
-            model.gpt.float()
 
-    # torch.compile for …
+    # Logic for torch.compile (requires Triton for some features)
    if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
-        logger.info("Applying torch.compile()...")
         try:
-            …
-            …
-            …
-                mode="reduce-overhead",
-                fullgraph=False
-            )
+            # We only compile the GPT part as it's the bottleneck
+            model.gpt = torch.compile(model.gpt, mode="reduce-overhead")
+            logger.info("GPT compiled successfully.")
         except Exception as e:
-            logger.warning(f"torch.compile failed: {e}")
+            logger.warning(f"torch.compile failed, skipping: {e}")
 
     model.eval()
-    logger.info(f"Model loaded on {device}")
-
     return model, config, device
 
 # Global model instance
@@ -264,9 +254,13 @@ def synthesize_streaming(
 
 
 def clear_cache():
-    """Clear speaker cache and CUDA memory"""
+    """Clear speaker cache and exhaustively free CUDA memory"""
     speaker_cache.clear()
-    …
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+    return "Cache and VRAM cleared!"
 
 
 # ============== Gradio Interface ==============
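For reference, the new Hub-loading path can be exercised outside the app. The sketch below only mirrors the calls added in this commit (TTS.api.TTS, synthesizer.tts_model, torch.compile on the GPT module); the script name, the bare hasattr guard, and the final print are illustrative assumptions, not part of app.py.

# smoke_load.py: minimal, standalone sketch of the Hub load path above.
import torch
from TTS.api import TTS

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same model string the commit uses; downloads from the Hub on first run
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
model = tts.synthesizer.tts_model    # raw XTTS module, needed for half()/compile()
config = tts.synthesizer.tts_config

if device == "cuda":
    model.half()                     # FP16 roughly halves VRAM
if hasattr(torch, "compile"):
    # Compile only the autoregressive GPT decoder, the latency bottleneck
    model.gpt = torch.compile(model.gpt, mode="reduce-overhead")

model.eval()
print(f"XTTS v2 ready on {device}")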
requirements.txt
CHANGED

@@ -1,9 +1,12 @@
-# Use the …
-…
-coqui-tts>=0.25.3
+# Use the latest stable Gradio to fix the JSON Schema / additionalProperties bug
+gradio>=5.9.1
 
-# …
-…
+# The PyPI package for Coqui TTS is 'tts'
+# We pin versions of transformers/tokenizers because XTTS is sensitive to their breaking changes
+tts==0.22.0
+transformers<=4.43.3
+tokenizers<=0.19.1
 
-# …
-…
+# High-performance inference
+deepspeed>=0.14.0
+huggingface_hub