Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,29 +2,23 @@ import os
|
|
| 2 |
import tempfile
|
| 3 |
import gradio as gr
|
| 4 |
from huggingface_hub import snapshot_download
|
| 5 |
-
|
| 6 |
-
# If torch is optional for you, you can keep this minimal
|
| 7 |
import torch
|
| 8 |
-
|
| 9 |
-
# Import after deps are installed (handled by requirements.txt)
|
| 10 |
from indextts.infer import IndexTTS
|
| 11 |
|
| 12 |
-
|
| 13 |
CHECKPOINTS_DIR = os.path.abspath("checkpoints")
|
| 14 |
|
| 15 |
def load_model():
|
| 16 |
"""
|
| 17 |
-
Download model weights (if needed) and initialize IndexTTS once.
|
| 18 |
-
Avoids the 'checkpoints/checkpoints' double-path bug by using the exact
|
| 19 |
-
path returned from snapshot_download.
|
| 20 |
"""
|
| 21 |
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
|
| 22 |
|
| 23 |
-
# Download
|
| 24 |
repo_path = snapshot_download(
|
| 25 |
repo_id="mlx-community/IndexTTS",
|
| 26 |
local_dir=CHECKPOINTS_DIR,
|
| 27 |
-
local_dir_use_symlinks=False,
|
| 28 |
allow_patterns=[
|
| 29 |
"config.yaml",
|
| 30 |
"bpe.model",
|
|
@@ -36,7 +30,14 @@ def load_model():
|
|
| 36 |
],
|
| 37 |
)
|
| 38 |
|
| 39 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 41 |
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 42 |
try:
|
|
@@ -44,12 +45,11 @@ def load_model():
|
|
| 44 |
except Exception:
|
| 45 |
pass
|
| 46 |
|
| 47 |
-
# Initialize IndexTTS
|
| 48 |
-
tts = IndexTTS(model_dir=repo_path, cfg_path=
|
| 49 |
return tts
|
| 50 |
|
| 51 |
-
|
| 52 |
-
# Global singleton (loaded once on Space startup)
|
| 53 |
_tts = None
|
| 54 |
def get_tts():
|
| 55 |
global _tts
|
|
@@ -57,37 +57,32 @@ def get_tts():
|
|
| 57 |
_tts = load_model()
|
| 58 |
return _tts
|
| 59 |
|
| 60 |
-
|
| 61 |
def synthesize(voice_path, text):
|
| 62 |
"""
|
| 63 |
Gradio inference function.
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
Returns
|
| 67 |
"""
|
| 68 |
if not voice_path or not os.path.exists(voice_path):
|
| 69 |
raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
|
| 70 |
-
|
| 71 |
if not text or not text.strip():
|
| 72 |
-
raise gr.Error("Please enter
|
| 73 |
|
| 74 |
tts = get_tts()
|
| 75 |
|
| 76 |
-
#
|
| 77 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 78 |
out_path = tmp.name
|
| 79 |
|
| 80 |
-
# Minimal call; IndexTTS handles normalization/phonemization internally.
|
| 81 |
-
# You can add extra kwargs if the library exposes them (e.g., speed, seed).
|
| 82 |
tts.infer(voice_path, text.strip(), out_path)
|
| 83 |
-
|
| 84 |
return out_path
|
| 85 |
|
| 86 |
-
|
| 87 |
title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
|
| 88 |
description = """
|
| 89 |
Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
|
| 90 |
-
This Space runs **IndexTTS** in CPU mode by default, so first run may take a
|
| 91 |
"""
|
| 92 |
|
| 93 |
with gr.Blocks() as demo:
|
|
@@ -95,30 +90,21 @@ with gr.Blocks() as demo:
|
|
| 95 |
|
| 96 |
with gr.Row():
|
| 97 |
with gr.Column():
|
| 98 |
-
voice = gr.Audio(
|
| 99 |
-
|
| 100 |
-
type="filepath",
|
| 101 |
-
label="Reference Voice (WAV preferred)"
|
| 102 |
-
)
|
| 103 |
-
text = gr.Textbox(
|
| 104 |
-
label="Text to Synthesize",
|
| 105 |
-
placeholder="Hello, how are you?",
|
| 106 |
-
lines=3
|
| 107 |
-
)
|
| 108 |
btn = gr.Button("Generate Speech")
|
| 109 |
-
|
| 110 |
with gr.Column():
|
| 111 |
audio_out = gr.Audio(label="Output Audio", type="filepath")
|
| 112 |
log = gr.Markdown("")
|
| 113 |
|
| 114 |
btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])
|
| 115 |
|
| 116 |
-
# Optional
|
| 117 |
def _startup():
|
| 118 |
try:
|
| 119 |
get_tts()
|
|
|
|
| 120 |
except Exception as e:
|
| 121 |
-
# Don't crash the Space if warmup fails; show a note in Logs.
|
| 122 |
print("Warmup failed:", e)
|
| 123 |
|
| 124 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import tempfile
|
| 3 |
import gradio as gr
|
| 4 |
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 5 |
import torch
|
|
|
|
|
|
|
| 6 |
from indextts.infer import IndexTTS
|
| 7 |
|
| 8 |
+
# Directory to store downloaded model files
|
| 9 |
CHECKPOINTS_DIR = os.path.abspath("checkpoints")
|
| 10 |
|
| 11 |
def load_model():
|
| 12 |
"""
|
| 13 |
+
Download IndexTTS model weights (if needed) and initialize IndexTTS once.
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
|
| 16 |
|
| 17 |
+
# Download weights from HF Hub
|
| 18 |
repo_path = snapshot_download(
|
| 19 |
repo_id="mlx-community/IndexTTS",
|
| 20 |
local_dir=CHECKPOINTS_DIR,
|
| 21 |
+
local_dir_use_symlinks=False,
|
| 22 |
allow_patterns=[
|
| 23 |
"config.yaml",
|
| 24 |
"bpe.model",
|
|
|
|
| 30 |
],
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# Debug: verify files
|
| 34 |
+
print("Downloaded files:", os.listdir(repo_path))
|
| 35 |
+
|
| 36 |
+
cfg_file = os.path.join(repo_path, "config.yaml")
|
| 37 |
+
if not os.path.exists(cfg_file):
|
| 38 |
+
raise FileNotFoundError(f"Cannot find config.yaml in {repo_path}. Check repo contents.")
|
| 39 |
+
|
| 40 |
+
# Limit CPU threads for Spaces
|
| 41 |
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 42 |
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 43 |
try:
|
|
|
|
| 45 |
except Exception:
|
| 46 |
pass
|
| 47 |
|
| 48 |
+
# Initialize IndexTTS
|
| 49 |
+
tts = IndexTTS(model_dir=repo_path, cfg_path=cfg_file)
|
| 50 |
return tts
|
| 51 |
|
| 52 |
+
# Global singleton for TTS
|
|
|
|
| 53 |
_tts = None
|
| 54 |
def get_tts():
|
| 55 |
global _tts
|
|
|
|
| 57 |
_tts = load_model()
|
| 58 |
return _tts
|
| 59 |
|
|
|
|
| 60 |
def synthesize(voice_path, text):
|
| 61 |
"""
|
| 62 |
Gradio inference function.
|
| 63 |
+
voice_path: path to reference voice (WAV recommended)
|
| 64 |
+
text: string to synthesize
|
| 65 |
+
Returns: path to output WAV
|
| 66 |
"""
|
| 67 |
if not voice_path or not os.path.exists(voice_path):
|
| 68 |
raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
|
|
|
|
| 69 |
if not text or not text.strip():
|
| 70 |
+
raise gr.Error("Please enter text to synthesize.")
|
| 71 |
|
| 72 |
tts = get_tts()
|
| 73 |
|
| 74 |
+
# Temporary output WAV
|
| 75 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 76 |
out_path = tmp.name
|
| 77 |
|
|
|
|
|
|
|
| 78 |
tts.infer(voice_path, text.strip(), out_path)
|
|
|
|
| 79 |
return out_path
|
| 80 |
|
| 81 |
+
# Gradio UI
|
| 82 |
title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
|
| 83 |
description = """
|
| 84 |
Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
|
| 85 |
+
This Space runs **IndexTTS** in CPU mode by default, so first run may take a while to warm up.
|
| 86 |
"""
|
| 87 |
|
| 88 |
with gr.Blocks() as demo:
|
|
|
|
| 90 |
|
| 91 |
with gr.Row():
|
| 92 |
with gr.Column():
|
| 93 |
+
voice = gr.Audio(sources=["upload"], type="filepath", label="Reference Voice (WAV preferred)")
|
| 94 |
+
text = gr.Textbox(label="Text to Synthesize", placeholder="Hello, how are you?", lines=3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
btn = gr.Button("Generate Speech")
|
|
|
|
| 96 |
with gr.Column():
|
| 97 |
audio_out = gr.Audio(label="Output Audio", type="filepath")
|
| 98 |
log = gr.Markdown("")
|
| 99 |
|
| 100 |
btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])
|
| 101 |
|
| 102 |
+
# Optional startup preload
|
| 103 |
def _startup():
|
| 104 |
try:
|
| 105 |
get_tts()
|
| 106 |
+
print("TTS model loaded successfully at startup.")
|
| 107 |
except Exception as e:
|
|
|
|
| 108 |
print("Warmup failed:", e)
|
| 109 |
|
| 110 |
if __name__ == "__main__":
|