drixo committed on
Commit
1f7b49e
·
verified ·
1 Parent(s): 813795e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -70
app.py CHANGED
@@ -1,84 +1,127 @@
 
 
1
  import gradio as gr
2
- import soundfile as sf
3
- import torch
4
- import sys, os
5
- from transformers import MarianMTModel, MarianTokenizer, pipeline
6
  from huggingface_hub import snapshot_download
7
 
8
# --------------------------
# Download Index-TTS repo from Hugging Face
# --------------------------
# Fetch the whole model repo into ./checkpoints; repo_path is the local checkout.
repo_path = snapshot_download("IndexTeam/Index-TTS", local_dir="checkpoints")
# Make the checkout importable — it ships the `indextts` package used below.
sys.path.append(repo_path)


from indextts.infer import IndexTTS

# Init TTS
# Voice-cloning TTS engine; config.yaml sits at the repo root.
tts = IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))

# --------------------------
# Translation models
# --------------------------
# UI label -> MarianMT checkpoint; the active pair is swapped at request time
# by translate_with_voice when the user changes direction.
language_models = {
    "Spanish β†’ English": "Helsinki-NLP/opus-mt-es-en",
    "English β†’ Spanish": "Helsinki-NLP/opus-mt-en-es"
}
# Module-level translation state (mutated in place by translate_with_voice).
current_model_name = language_models["Spanish β†’ English"]
tokenizer = MarianTokenizer.from_pretrained(current_model_name)
model = MarianMTModel.from_pretrained(current_model_name)

# Speech-to-text (ASR)
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
32
-
33
- # --------------------------
34
- # Functions
35
- # --------------------------
36
def text_to_speech(text, ref_voice):
    """Speak `text` in the style of `ref_voice` via the global IndexTTS engine.

    text: string to synthesize.
    ref_voice: path to the reference-voice clip handed to tts.infer.
    Returns (samplerate, data) — the tuple shape expected by a
    gr.Audio(type="numpy") output component.
    """
    # NOTE(review): fixed output filename — concurrent requests would clobber
    # each other's WAV; confirm the app serves one request at a time.
    output_path = "output.wav"
    tts.infer(ref_voice, text, output_path)
    # Read the synthesized file back so the caller gets in-memory audio.
    data, samplerate = sf.read(output_path)
    return samplerate, data
41
-
42
def translate_with_voice(audio, lang_pair, ref_voice):
    """Transcribe `audio`, translate it, and re-speak it in the reference voice.

    audio: path of the recorded input clip (fed to the Whisper ASR pipeline).
    lang_pair: key into `language_models` selecting the translation direction.
    ref_voice: path of the reference-voice clip used for cloning.
    Returns (translated_text, (sample_rate, samples)).
    """
    # 1) Speech to text
    text_input = asr(audio)["text"]

    # 2) Translation
    # Reload the MarianMT model/tokenizer into module state only when the
    # requested direction differs from the one currently loaded.
    global tokenizer, model, current_model_name
    if language_models[lang_pair] != current_model_name:
        current_model_name = language_models[lang_pair]
        tokenizer = MarianTokenizer.from_pretrained(current_model_name)
        model = MarianMTModel.from_pretrained(current_model_name)

    inputs = tokenizer(text_input, return_tensors="pt", padding=True)
    translated = model.generate(**inputs)
    # Single input sentence -> decode only the first generated sequence.
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)

    # 3) Text to speech
    sr, audio_array = text_to_speech(translated_text, ref_voice)
    return translated_text, (sr, audio_array)
60
-
61
- # --------------------------
62
- # Gradio UI
63
- # --------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
# --------------------------
# Gradio UI
# --------------------------
with gr.Blocks() as demo:
    gr.Markdown("## πŸ—£ Voice-Cloned Translator (English ↔ Spanish)")

    with gr.Row():
        # Left column: inputs (speech, direction, reference voice).
        with gr.Column():
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="πŸŽ™ Speak")
            lang_dropdown = gr.Dropdown(list(language_models.keys()), label="🌍 Target Language", value="Spanish β†’ English")
            ref_voice_input = gr.Audio(sources=["upload"], type="filepath", label="🎧 Reference Voice (5–10s)")
            btn = gr.Button("Translate & Speak")

        # Right column: outputs (translated text + cloned-voice audio).
        with gr.Column():
            text_output = gr.Textbox(label="Translated Text")
            audio_output = gr.Audio(label="πŸ”Š Translated Audio", type="numpy")

    # Wire the button to the full ASR -> MT -> TTS pipeline.
    btn.click(
        fn=translate_with_voice,
        inputs=[audio_input, lang_dropdown, ref_voice_input],
        outputs=[text_output, audio_output]
    )

demo.launch()
 
1
+ import os
2
+ import tempfile
3
  import gradio as gr
 
 
 
 
4
  from huggingface_hub import snapshot_download
5
 
6
+ # If torch is optional for you, you can keep this minimal
7
+ import torch
 
 
 
8
 
9
+ # Import after deps are installed (handled by requirements.txt)
10
  from indextts.infer import IndexTTS
11
 
12
+
13
# Resolved once so the download target never depends on the working directory.
CHECKPOINTS_DIR = os.path.abspath("checkpoints")


def load_model():
    """Download the IndexTTS weights (if needed) and build the engine once.

    The path returned by ``snapshot_download`` is used verbatim when
    constructing IndexTTS, which avoids the 'checkpoints/checkpoints'
    double-path mistake of prefixing it again.

    Returns the initialized IndexTTS instance.
    """
    # Keep CPU usage predictable on shared Space hardware.
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    try:
        torch.set_num_threads(1)
    except Exception:
        # Thread tuning is best-effort; never block startup on it.
        pass

    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

    # Only the files the engine needs at load/inference time.
    wanted_files = [
        "config.yaml",
        "bpe.model",
        "unigram_12000.vocab",
        "gpt.pth",
        "bigvgan_generator.pth",
        "bigvgan_discriminator.pth",
        "dvae.pth",
    ]

    # Download into a fixed directory; do NOT prefix this path again later.
    # NOTE(review): `local_dir_use_symlinks` is deprecated (ignored) in recent
    # huggingface_hub releases — confirm the pinned version still honors it.
    repo_path = snapshot_download(
        repo_id="mlx-community/IndexTTS",
        local_dir=CHECKPOINTS_DIR,
        local_dir_use_symlinks=False,  # ensures real files (safer in Spaces)
        allow_patterns=wanted_files,
    )

    # IMPORTANT: pass the snapshot path straight through.
    return IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))
50
+
51
+
52
# Process-wide singleton: the TTS engine is built at most once per Space run.
_tts = None


def get_tts():
    """Return the shared IndexTTS instance, constructing it on first use."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = load_model()
    return _tts
59
+
60
+
61
def synthesize(voice_path, text):
    """Gradio handler: clone the reference voice and speak `text`.

    voice_path: filesystem path of the uploaded reference clip (WAV recommended).
    text: the text to synthesize.
    Returns the path of the generated WAV file (served by Gradio).
    Raises gr.Error with a user-facing message when inputs are missing.
    """
    # Guard clauses: surface friendly errors in the UI instead of tracebacks.
    if not (voice_path and os.path.exists(voice_path)):
        raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
    message = (text or "").strip()
    if not message:
        raise gr.Error("Please enter the text to speak.")

    engine = get_tts()

    # Allocate a .wav that outlives this call (delete=False) so Gradio can serve it.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    out_path = tmp.name
    tmp.close()

    # Minimal call; IndexTTS handles normalization/phonemization internally.
    engine.infer(voice_path, message, out_path)
    return out_path
85
+
86
+
87
# UI copy shown at the top of the Space.
title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
description = """
Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
This Space runs **IndexTTS** in CPU mode by default, so first run may take a bit to warm up.
"""

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}\n{description}")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            voice = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Reference Voice (WAV preferred)"
            )
            text = gr.Textbox(
                label="Text to Synthesize",
                placeholder="Hello, how are you?",
                lines=3
            )
            btn = gr.Button("Generate Speech")

        # Right column: output audio (served by file path).
        with gr.Column():
            audio_out = gr.Audio(label="Output Audio", type="filepath")
            # NOTE(review): `log` is created but never updated by any handler.
            log = gr.Markdown("")

    btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])

# Optional: pre-load at startup so first user call is faster
def _startup():
    # Best-effort warmup of the TTS singleton.
    try:
        get_tts()
    except Exception as e:
        # Don't crash the Space if warmup fails; show a note in Logs.
        print("Warmup failed:", e)

if __name__ == "__main__":
    _startup()
    demo.launch(server_name="0.0.0.0", server_port=7860)
127