fosters committed on
Commit e590dfc · verified · 1 Parent(s): 78ed7b9

Upload 2 files

Files changed (2)
  1. app.py +23 -29
  2. requirements.txt +10 -7
app.py CHANGED
@@ -37,18 +37,18 @@ STREAMING_CHUNK_SIZE = int(os.environ.get("STREAMING_CHUNK_SIZE", "20"))
 # ============== Model Loading ==============
 def load_model():
     """Load XTTSv2 with all optimizations"""
-    from TTS.api import TTS
+    # Import inside function to prevent early CUDA initialization
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts
+    from TTS.api import TTS
 
     logger.info("Loading XTTSv2 model...")
 
-    # Check if local model exists, otherwise use default from HF Hub
+    # Check if local model exists
     local_config = os.path.join(MODEL_PATH, "config.json")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     if os.path.exists(local_config):
-        # Load local/fine-tuned model
-        logger.info(f"Loading local model from {MODEL_PATH}")
         config = XttsConfig()
         config.load_json(local_config)
         model = Xtts.init_from_config(config)
@@ -59,39 +59,29 @@ def load_model():
             use_deepspeed=USE_DEEPSPEED
         )
     else:
-        # Load default XTTS-v2 from Hugging Face Hub via TTS API
-        logger.info("Loading default coqui/XTTS-v2 model from Hugging Face Hub...")
-        tts_api = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=torch.cuda.is_available())
-        model = tts_api.synthesizer.tts_model
-        config = tts_api.synthesizer.tts_config
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(device)
+        # Reverting to the high-level API for Hub loads as it handles weights better
+        logger.info("Loading default coqui/XTTS-v2 from Hub...")
+        # We use the synthesizer directly to access the model object for optimizations
+        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
+        model = tts.synthesizer.tts_model
+        config = tts.synthesizer.tts_config
+
+    model.to(device)
 
-    # FP16 optimization
     if USE_FP16 and device == "cuda":
         logger.info("Enabling FP16 inference...")
         model.half()
-        # Keep some layers in FP32 for stability
-        if hasattr(model, 'gpt'):
-            model.gpt.float()
 
-    # torch.compile for PyTorch 2.0+
+    # Logic for torch.compile (requires Triton for some features)
     if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
-        logger.info("Applying torch.compile()...")
         try:
-            if hasattr(model, 'hifigan_decoder'):
-                model.hifigan_decoder = torch.compile(
-                    model.hifigan_decoder,
-                    mode="reduce-overhead",
-                    fullgraph=False
-                )
+            # We only compile the GPT part as it's the bottleneck
+            model.gpt = torch.compile(model.gpt, mode="reduce-overhead")
+            logger.info("GPT compiled successfully.")
        except Exception as e:
-            logger.warning(f"torch.compile failed: {e}")
+            logger.warning(f"torch.compile failed, skipping: {e}")
 
     model.eval()
-    logger.info(f"Model loaded on {device}")
-
     return model, config, device
 
 # Global model instance
@@ -264,9 +254,13 @@ def synthesize_streaming(
 
 
 def clear_cache():
-    """Clear speaker cache and CUDA memory"""
+    """Clear speaker cache and exhaustively free CUDA memory"""
     speaker_cache.clear()
-    return "Cache cleared!"
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+    return "Cache and VRAM cleared!"
 
 
 # ============== Gradio Interface ==============
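A note on the compile hunk above: torch.compile is lazy, so `model.gpt` is only actually compiled on its first forward pass, and that first synthesis request absorbs the compilation cost. A minimal sketch of that behavior, using a placeholder module rather than the real XTTS GPT (a real warm-up would push one throwaway request through the model returned by load_model):

```python
import torch
import torch.nn as nn

# Placeholder standing in for model.gpt; the module and shapes are arbitrary.
gpt = nn.Linear(1024, 1024)

if hasattr(torch, "compile"):
    gpt = torch.compile(gpt, mode="reduce-overhead")

with torch.inference_mode():
    gpt(torch.randn(1, 1024))  # first call triggers (and pays for) compilation
    gpt(torch.randn(1, 1024))  # later calls reuse the compiled graph
```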
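The reworked clear_cache also depends on `gc` and `torch` being available at app.py's module level (the hunk itself does not add the import). Below is a self-contained sketch of how it might be wired into the Gradio interface that follows; the components and layout are hypothetical, with clear_cache reproduced verbatim from the hunk:

```python
import gc
import torch
import gradio as gr

speaker_cache: dict = {}  # stand-in for app.py's real speaker cache

def clear_cache():
    """Clear speaker cache and exhaustively free CUDA memory"""
    speaker_cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return "Cache and VRAM cleared!"

# Hypothetical wiring: a button that runs clear_cache and shows its status string.
with gr.Blocks() as demo:
    status = gr.Textbox(label="Status", interactive=False)
    gr.Button("Clear cache").click(fn=clear_cache, outputs=status)

if __name__ == "__main__":
    demo.launch()
```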
requirements.txt CHANGED
@@ -1,9 +1,12 @@
-# Use the maintained Idiap fork instead of abandoned coqui-ai/TTS
-# This fixes transformers compatibility and is actively maintained
-coqui-tts>=0.25.3
+# Use the latest stable Gradio to fix the JSON Schema / additionalProperties bug
+gradio>=5.9.1
 
-# Gradio UI - pin to stable version (5.6.0 has JSON schema bug)
-gradio==5.5.0
+# The PyPI package for Coqui TTS is 'tts'
+# We pin versions of transformers/tokenizers because XTTS is sensitive to their breaking changes
+tts==0.22.0
+transformers<=4.43.3
+tokenizers<=0.19.1
 
-# Hugging Face
-huggingface_hub
+# High-performance inference
+deepspeed>=0.14.0
+huggingface_hub
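The transformers/tokenizers pins exist because XTTS breaks on newer releases of those libraries, so it is worth confirming what actually got installed before the app starts. A small, assumed sanity check (not part of this commit) that prints the resolved versions to the build log:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages pinned (or floated) in requirements.txt above.
for pkg in ("gradio", "TTS", "transformers", "tokenizers", "deepspeed"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```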