drixo committed on
Commit
b0d995c
·
verified ·
1 Parent(s): aa4d216

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -40
app.py CHANGED
@@ -2,29 +2,23 @@ import os
2
  import tempfile
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
-
6
- # If torch is optional for you, you can keep this minimal
7
  import torch
8
-
9
- # Import after deps are installed (handled by requirements.txt)
10
  from indextts.infer import IndexTTS
11
 
12
-
13
  CHECKPOINTS_DIR = os.path.abspath("checkpoints")
14
 
15
  def load_model():
16
  """
17
- Download model weights (if needed) and initialize IndexTTS once.
18
- Avoids the 'checkpoints/checkpoints' double-path bug by using the exact
19
- path returned from snapshot_download.
20
  """
21
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
22
 
23
- # Download to a fixed directory; do NOT prefix this path again later.
24
  repo_path = snapshot_download(
25
  repo_id="mlx-community/IndexTTS",
26
  local_dir=CHECKPOINTS_DIR,
27
- local_dir_use_symlinks=False, # ensures real files (safer in Spaces)
28
  allow_patterns=[
29
  "config.yaml",
30
  "bpe.model",
@@ -36,7 +30,14 @@ def load_model():
36
  ],
37
  )
38
 
39
- # Optional: keep CPU stable in Spaces and prevent over-threading
 
 
 
 
 
 
 
40
  os.environ.setdefault("OMP_NUM_THREADS", "1")
41
  os.environ.setdefault("MKL_NUM_THREADS", "1")
42
  try:
@@ -44,12 +45,11 @@ def load_model():
44
  except Exception:
45
  pass
46
 
47
- # Initialize IndexTTS. IMPORTANT: pass repo_path directly.
48
- tts = IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))
49
  return tts
50
 
51
-
52
- # Global singleton (loaded once on Space startup)
53
  _tts = None
54
  def get_tts():
55
  global _tts
@@ -57,37 +57,32 @@ def get_tts():
57
  _tts = load_model()
58
  return _tts
59
 
60
-
61
  def synthesize(voice_path, text):
62
  """
63
  Gradio inference function.
64
- - voice_path: path to uploaded reference voice (WAV strongly recommended)
65
- - text: the text to speak
66
- Returns (output_wav_path)
67
  """
68
  if not voice_path or not os.path.exists(voice_path):
69
  raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
70
-
71
  if not text or not text.strip():
72
- raise gr.Error("Please enter the text to speak.")
73
 
74
  tts = get_tts()
75
 
76
- # Write output to a temporary WAV file; Gradio will serve it.
77
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
78
  out_path = tmp.name
79
 
80
- # Minimal call; IndexTTS handles normalization/phonemization internally.
81
- # You can add extra kwargs if the library exposes them (e.g., speed, seed).
82
  tts.infer(voice_path, text.strip(), out_path)
83
-
84
  return out_path
85
 
86
-
87
  title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
88
  description = """
89
  Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
90
- This Space runs **IndexTTS** in CPU mode by default, so first run may take a bit to warm up.
91
  """
92
 
93
  with gr.Blocks() as demo:
@@ -95,30 +90,21 @@ with gr.Blocks() as demo:
95
 
96
  with gr.Row():
97
  with gr.Column():
98
- voice = gr.Audio(
99
- sources=["upload"],
100
- type="filepath",
101
- label="Reference Voice (WAV preferred)"
102
- )
103
- text = gr.Textbox(
104
- label="Text to Synthesize",
105
- placeholder="Hello, how are you?",
106
- lines=3
107
- )
108
  btn = gr.Button("Generate Speech")
109
-
110
  with gr.Column():
111
  audio_out = gr.Audio(label="Output Audio", type="filepath")
112
  log = gr.Markdown("")
113
 
114
  btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])
115
 
116
- # Optional: pre-load at startup so first user call is faster
117
  def _startup():
118
  try:
119
  get_tts()
 
120
  except Exception as e:
121
- # Don't crash the Space if warmup fails; show a note in Logs.
122
  print("Warmup failed:", e)
123
 
124
  if __name__ == "__main__":
 
2
  import tempfile
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
 
 
5
  import torch
 
 
6
  from indextts.infer import IndexTTS
7
 
8
+ # Directory to store downloaded model files
9
  CHECKPOINTS_DIR = os.path.abspath("checkpoints")
10
 
11
  def load_model():
12
  """
13
+ Download IndexTTS model weights (if needed) and initialize IndexTTS once.
 
 
14
  """
15
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
16
 
17
+ # Download weights from HF Hub
18
  repo_path = snapshot_download(
19
  repo_id="mlx-community/IndexTTS",
20
  local_dir=CHECKPOINTS_DIR,
21
+ local_dir_use_symlinks=False,
22
  allow_patterns=[
23
  "config.yaml",
24
  "bpe.model",
 
30
  ],
31
  )
32
 
33
+ # Debug: verify files
34
+ print("Downloaded files:", os.listdir(repo_path))
35
+
36
+ cfg_file = os.path.join(repo_path, "config.yaml")
37
+ if not os.path.exists(cfg_file):
38
+ raise FileNotFoundError(f"Cannot find config.yaml in {repo_path}. Check repo contents.")
39
+
40
+ # Limit CPU threads for Spaces
41
  os.environ.setdefault("OMP_NUM_THREADS", "1")
42
  os.environ.setdefault("MKL_NUM_THREADS", "1")
43
  try:
 
45
  except Exception:
46
  pass
47
 
48
+ # Initialize IndexTTS
49
+ tts = IndexTTS(model_dir=repo_path, cfg_path=cfg_file)
50
  return tts
51
 
52
+ # Global singleton for TTS
 
53
  _tts = None
54
  def get_tts():
55
  global _tts
 
57
  _tts = load_model()
58
  return _tts
59
 
 
60
  def synthesize(voice_path, text):
61
  """
62
  Gradio inference function.
63
+ voice_path: path to reference voice (WAV recommended)
64
+ text: string to synthesize
65
+ Returns: path to output WAV
66
  """
67
  if not voice_path or not os.path.exists(voice_path):
68
  raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
 
69
  if not text or not text.strip():
70
+ raise gr.Error("Please enter text to synthesize.")
71
 
72
  tts = get_tts()
73
 
74
+ # Temporary output WAV
75
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
76
  out_path = tmp.name
77
 
 
 
78
  tts.infer(voice_path, text.strip(), out_path)
 
79
  return out_path
80
 
81
+ # Gradio UI
82
  title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
83
  description = """
84
  Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
85
+ This Space runs **IndexTTS** in CPU mode by default, so first run may take a while to warm up.
86
  """
87
 
88
  with gr.Blocks() as demo:
 
90
 
91
  with gr.Row():
92
  with gr.Column():
93
+ voice = gr.Audio(sources=["upload"], type="filepath", label="Reference Voice (WAV preferred)")
94
+ text = gr.Textbox(label="Text to Synthesize", placeholder="Hello, how are you?", lines=3)
 
 
 
 
 
 
 
 
95
  btn = gr.Button("Generate Speech")
 
96
  with gr.Column():
97
  audio_out = gr.Audio(label="Output Audio", type="filepath")
98
  log = gr.Markdown("")
99
 
100
  btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])
101
 
102
+ # Optional startup preload
103
  def _startup():
104
  try:
105
  get_tts()
106
+ print("TTS model loaded successfully at startup.")
107
  except Exception as e:
 
108
  print("Warmup failed:", e)
109
 
110
  if __name__ == "__main__":