drixo committed on
Commit
057f29d
·
verified ·
1 Parent(s): 2a191ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -100
app.py CHANGED
@@ -1,113 +1,84 @@
1
- import os
2
- import tempfile
3
  import gradio as gr
4
- from huggingface_hub import snapshot_download
5
  import torch
6
- from indextts.infer import IndexTTS
7
-
8
# Directory to store downloaded model files
CHECKPOINTS_DIR = os.path.abspath("checkpoints")


def load_model():
    """
    Fetch the IndexTTS checkpoint files and build an IndexTTS instance.

    The files are requested from the "mlx-community/IndexTTS" repo on the
    Hugging Face Hub into CHECKPOINTS_DIR, and the model is then constructed
    from that directory and its config.yaml.

    Returns:
        The initialized IndexTTS object.

    Raises:
        FileNotFoundError: if config.yaml is not a file in the download
            directory once the download has finished.
    """
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

    # Restrict the download to the files listed below.
    repo_path = snapshot_download(
        repo_id="mlx-community/IndexTTS",
        local_dir=CHECKPOINTS_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=[
            "config.yaml",
            "bpe.model",
            "unigram_12000.vocab",
            "gpt.pth",
            "bigvgan_generator.pth",
            "bigvgan_discriminator.pth",
            "dvae.pth",
        ],
    )

    # Debug: verify files
    print("Downloaded files:", os.listdir(repo_path))

    cfg_file = os.path.join(repo_path, "config.yaml")
    # isfile (not exists): a directory with that name is not a usable config.
    if not os.path.isfile(cfg_file):
        raise FileNotFoundError(f"Cannot find config.yaml in {repo_path}. Check repo contents.")

    # Limit CPU threads for Spaces
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    try:
        torch.set_num_threads(1)
    except Exception as exc:
        # Best effort only: keep going with torch's default thread count,
        # but report the failure instead of hiding it.
        print("Could not limit torch threads:", exc)

    # Initialize IndexTTS
    return IndexTTS(model_dir=repo_path, cfg_path=cfg_file)
51
-
52
# Shared IndexTTS instance, created on the first call to get_tts()
_tts = None


def get_tts():
    """Return the shared TTS object, calling load_model() only if none exists yet."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = load_model()
    return _tts
59
-
60
def synthesize(voice_path, text):
    """
    Gradio inference function.

    Args:
        voice_path: path to the reference voice file (WAV recommended).
        text: string to synthesize; surrounding whitespace is stripped.

    Returns:
        Path to the generated WAV file.

    Raises:
        gr.Error: if the reference clip is missing or the text is blank.
    """
    import contextlib  # local import: not part of the module's import header

    if not voice_path or not os.path.exists(voice_path):
        raise gr.Error("Please upload a short reference voice clip (WAV recommended).")
    cleaned_text = text.strip() if text else ""
    if not cleaned_text:
        raise gr.Error("Please enter text to synthesize.")

    tts = get_tts()

    # Reserve a unique output path; the file must outlive this function
    # because the path is what gets returned to Gradio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name

    try:
        tts.infer(voice_path, cleaned_text, out_path)
    except Exception:
        # Synthesis failed: nobody will ever read this file, so do not leave
        # it behind in the temp directory.
        with contextlib.suppress(OSError):
            os.remove(out_path)
        raise
    return out_path
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Gradio UI
82
# Page heading and introductory text rendered at the top of the app
title = "IndexTTS – Zero-shot Voice Cloning (HF Space)"
description = """
Upload a short **reference voice** (5–10s, clean speech works best) and enter text.
This Space runs **IndexTTS** in CPU mode by default, so first run may take a while to warm up.
"""

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}\n{description}")

    with gr.Row():
        # Left column: inputs and the trigger button
        with gr.Column():
            voice = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Reference Voice (WAV preferred)",
            )
            text = gr.Textbox(
                label="Text to Synthesize",
                placeholder="Hello, how are you?",
                lines=3,
            )
            btn = gr.Button("Generate Speech")
        # Right column: generated audio plus an (initially empty) markdown area
        with gr.Column():
            audio_out = gr.Audio(label="Output Audio", type="filepath")
            log = gr.Markdown("")

    # Clicking the button runs synthesize(voice, text) and shows the result.
    btn.click(fn=synthesize, inputs=[voice, text], outputs=[audio_out])
101
 
102
# Optional startup preload
def _startup():
    """Try to load the TTS model before the first request.

    A failure is printed instead of raised, so the caller can still go on
    to launch the app.
    """
    try:
        get_tts()
        print("TTS model loaded successfully at startup.")
    except Exception as warmup_error:
        print("Warmup failed:", warmup_error)


if __name__ == "__main__":
    # Warm up first, then serve on all interfaces at port 7860.
    _startup()
    demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
113
 
 
 
 
 
1
  import gradio as gr
2
+ import soundfile as sf
3
  import torch
4
+ import sys, os
5
+ from transformers import MarianMTModel, MarianTokenizer, pipeline
6
+ from huggingface_hub import snapshot_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# --------------------------
# Download Index-TTS repo from Hugging Face
# --------------------------
# repo_path is used below as the local folder that holds the downloaded files.
repo_path = snapshot_download(repo_id="IndexTeam/Index-TTS", local_dir="checkpoints")
# Presumably added so the `indextts` import below can resolve from the
# downloaded folder — TODO confirm the package is part of that download.
sys.path.append(repo_path)

from indextts.infer import IndexTTS

# Init TTS
config_path = os.path.join(repo_path, "config.yaml")
tts = IndexTTS(model_dir=repo_path, cfg_path=config_path)
18
+
19
# --------------------------
# Translation models
# --------------------------
# Maps the label shown in the UI to the MarianMT checkpoint used for that direction.
language_models = {
    "Spanish β†’ English": "Helsinki-NLP/opus-mt-es-en",
    "English β†’ Spanish": "Helsinki-NLP/opus-mt-en-es",
}

# Start with the Spanish-to-English pair loaded; translate_with_voice replaces
# these three globals when a different pair is requested.
current_model_name = language_models["Spanish β†’ English"]
tokenizer = MarianTokenizer.from_pretrained(current_model_name)
model = MarianMTModel.from_pretrained(current_model_name)

# Speech-to-text (ASR)
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
)
32
+
33
+ # --------------------------
34
+ # Functions
35
+ # --------------------------
36
def text_to_speech(text, ref_voice):
    """
    Synthesize `text` with the voice of the `ref_voice` clip.

    Args:
        text: the text to speak.
        ref_voice: file path of the reference voice recording.

    Returns:
        A `(samplerate, data)` tuple as read back from the generated WAV file.

    Raises:
        gr.Error: if the reference clip or the text is missing.
    """
    # Local imports: neither module is part of the file's import header.
    import contextlib
    import tempfile

    if not ref_voice:
        raise gr.Error("Please upload a reference voice clip.")
    if not text or not text.strip():
        raise gr.Error("There is no text to synthesize.")

    # Use a unique file per call: a fixed "output.wav" would be shared by all
    # requests, so two overlapping calls could overwrite each other's audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name

    try:
        tts.infer(ref_voice, text, output_path)
        data, samplerate = sf.read(output_path)
    finally:
        # The samples are already in memory (or the call failed), so the
        # file on disk is no longer needed either way.
        with contextlib.suppress(OSError):
            os.remove(output_path)
    return samplerate, data
41
+
42
def translate_with_voice(audio, lang_pair, ref_voice):
    """
    Transcribe recorded speech, translate it, and speak the translation.

    Args:
        audio: file path of the recording to transcribe.
        lang_pair: a key of `language_models` selecting the translation direction.
        ref_voice: file path of the reference voice passed on to text_to_speech.

    Returns:
        `(translated_text, (sample_rate, audio_array))`.

    Raises:
        gr.Error: if the recording is missing, the language pair is unknown,
            or no speech could be transcribed.
    """
    global tokenizer, model, current_model_name

    if not audio:
        raise gr.Error("Please record some speech to translate.")
    wanted_model_name = language_models.get(lang_pair)
    if wanted_model_name is None:
        raise gr.Error(f"Unsupported language pair: {lang_pair!r}")

    # 1) Speech to text
    text_input = asr(audio)["text"].strip()
    if not text_input:
        raise gr.Error("No speech was recognized in the recording.")

    # 2) Translation
    if wanted_model_name != current_model_name:
        # Load both pieces first and only then publish them together. If a
        # load fails, the globals still describe the previously working
        # model, and the next call will retry instead of silently
        # translating with the wrong direction.
        new_tokenizer = MarianTokenizer.from_pretrained(wanted_model_name)
        new_model = MarianMTModel.from_pretrained(wanted_model_name)
        tokenizer, model, current_model_name = new_tokenizer, new_model, wanted_model_name

    # truncation=True keeps over-long transcripts within the model's input limit.
    inputs = tokenizer(text_input, return_tensors="pt", padding=True, truncation=True)
    translated = model.generate(**inputs)
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)

    # 3) Text to speech
    sr, audio_array = text_to_speech(translated_text, ref_voice)
    return translated_text, (sr, audio_array)
60
+
61
# --------------------------
# Gradio UI
# --------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🗣 Voice-Cloned Translator (English ↔ Spanish)")

    with gr.Row():
        # Left column: the recording, the translation direction, the voice
        # to imitate, and the trigger button.
        with gr.Column():
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙 Speak")
            lang_dropdown = gr.Dropdown(
                list(language_models),
                label="🌍 Target Language",
                # Default to the first configured pair instead of repeating
                # its key as a literal that could drift out of sync.
                value=next(iter(language_models), None),
            )
            ref_voice_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="🎧 Reference Voice (5–10s)",
            )
            btn = gr.Button("Translate & Speak")

        # Right column: the translated text and the synthesized audio.
        with gr.Column():
            text_output = gr.Textbox(label="Translated Text")
            audio_output = gr.Audio(label="🔊 Translated Audio", type="numpy")

    btn.click(
        fn=translate_with_voice,
        inputs=[audio_input, lang_dropdown, ref_voice_input],
        outputs=[text_output, audio_output],
    )

# Only start the server when run as a script, so importing this module
# does not block on launch().
if __name__ == "__main__":
    demo.launch()