drixo committed on
Commit
863b347
·
verified ·
1 Parent(s): 057f29d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -29
app.py CHANGED
@@ -1,20 +1,54 @@
 
 
 
1
  import gradio as gr
2
  import soundfile as sf
3
  import torch
4
- import sys, os
5
- from transformers import MarianMTModel, MarianTokenizer, pipeline
6
  from huggingface_hub import snapshot_download
 
7
 
8
  # --------------------------
9
- # Download Index-TTS repo from Hugging Face
10
  # --------------------------
11
- repo_path = snapshot_download("IndexTeam/Index-TTS", local_dir="checkpoints")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  sys.path.append(repo_path)
13
 
14
  from indextts.infer import IndexTTS
15
 
16
- # Init TTS
17
- tts = IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # --------------------------
20
  # Translation models
@@ -23,46 +57,68 @@ language_models = {
23
  "Spanish → English": "Helsinki-NLP/opus-mt-es-en",
24
  "English → Spanish": "Helsinki-NLP/opus-mt-en-es"
25
  }
26
- current_model_name = language_models["Spanish → English"]
27
- tokenizer = MarianTokenizer.from_pretrained(current_model_name)
28
- model = MarianMTModel.from_pretrained(current_model_name)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Speech-to-text (ASR)
 
31
  asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
32
 
33
  # --------------------------
34
- # Functions
35
  # --------------------------
36
- def text_to_speech(text, ref_voice):
37
- output_path = "output.wav"
38
- tts.infer(ref_voice, text, output_path)
39
- data, samplerate = sf.read(output_path)
40
- return samplerate, data
 
 
 
 
 
41
 
42
  def translate_with_voice(audio, lang_pair, ref_voice):
 
 
 
 
 
 
43
  # 1) Speech to text
44
- text_input = asr(audio)["text"]
45
 
46
  # 2) Translation
47
- global tokenizer, model, current_model_name
48
- if language_models[lang_pair] != current_model_name:
49
- current_model_name = language_models[lang_pair]
50
- tokenizer = MarianTokenizer.from_pretrained(current_model_name)
51
- model = MarianMTModel.from_pretrained(current_model_name)
52
-
53
  inputs = tokenizer(text_input, return_tensors="pt", padding=True)
54
- translated = model.generate(**inputs)
55
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
56
 
57
  # 3) Text to speech
58
- sr, audio_array = text_to_speech(translated_text, ref_voice)
59
- return translated_text, (sr, audio_array)
60
 
61
  # --------------------------
62
  # Gradio UI
63
  # --------------------------
 
 
 
 
 
 
64
  with gr.Blocks() as demo:
65
- gr.Markdown("## 🗣 Voice-Cloned Translator (English ↔ Spanish)")
66
 
67
  with gr.Row():
68
  with gr.Column():
@@ -73,7 +129,7 @@ with gr.Blocks() as demo:
73
 
74
  with gr.Column():
75
  text_output = gr.Textbox(label="Translated Text")
76
- audio_output = gr.Audio(label="🔊 Translated Audio", type="numpy")
77
 
78
  btn.click(
79
  fn=translate_with_voice,
@@ -81,4 +137,14 @@ with gr.Blocks() as demo:
81
  outputs=[text_output, audio_output]
82
  )
83
 
84
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import tempfile
4
  import gradio as gr
5
  import soundfile as sf
6
  import torch
 
 
7
  from huggingface_hub import snapshot_download
8
+ from transformers import MarianMTModel, MarianTokenizer, pipeline
9
 
10
  # --------------------------
11
+ # Download IndexTTS repo from Hugging Face
12
  # --------------------------
13
+ CHECKPOINTS_DIR = os.path.abspath("checkpoints")
14
+ os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
15
+
16
+ repo_path = snapshot_download(
17
+ repo_id="mlx-community/IndexTTS", # Correct repo
18
+ local_dir=CHECKPOINTS_DIR,
19
+ local_dir_use_symlinks=False,
20
+ allow_patterns=[
21
+ "config.yaml",
22
+ "bpe.model",
23
+ "unigram_12000.vocab",
24
+ "gpt.pth",
25
+ "bigvgan_generator.pth",
26
+ "bigvgan_discriminator.pth",
27
+ "dvae.pth",
28
+ ],
29
+ )
30
  sys.path.append(repo_path)
31
 
32
  from indextts.infer import IndexTTS
33
 
34
+ # --------------------------
35
+ # Initialize TTS safely
36
+ # --------------------------
37
+ _tts = None
38
+ def get_tts():
39
+ global _tts
40
+ if _tts is None:
41
+ try:
42
+ _tts = IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))
43
+ except FileNotFoundError as e:
44
+ print("Error loading IndexTTS:", e)
45
+ raise gr.Error("IndexTTS model files not found!")
46
+ return _tts
47
+
48
+ # Limit CPU threads (important for Spaces)
49
+ torch.set_num_threads(1)
50
+ os.environ["OMP_NUM_THREADS"] = "1"
51
+ os.environ["MKL_NUM_THREADS"] = "1"
52
 
53
  # --------------------------
54
  # Translation models
 
57
  "Spanish → English": "Helsinki-NLP/opus-mt-es-en",
58
  "English → Spanish": "Helsinki-NLP/opus-mt-en-es"
59
  }
 
 
 
60
 
61
+ current_model_name = None
62
+ tokenizer = None
63
+ model = None
64
+
65
+ def load_translation_model(lang_pair):
66
+ global current_model_name, tokenizer, model
67
+ if language_models[lang_pair] != current_model_name:
68
+ current_model_name = language_models[lang_pair]
69
+ tokenizer = MarianTokenizer.from_pretrained(current_model_name)
70
+ model = MarianMTModel.from_pretrained(current_model_name)
71
+
72
+ # --------------------------
73
  # Speech-to-text (ASR)
74
+ # --------------------------
75
  asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
76
 
77
  # --------------------------
78
+ # Core functions
79
  # --------------------------
80
+ def text_to_speech(text, ref_voice_path):
81
+ """
82
+ Convert text to speech using IndexTTS.
83
+ Returns a temporary WAV file path.
84
+ """
85
+ tts = get_tts()
86
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
87
+ out_path = tmp.name
88
+ tts.infer(ref_voice_path, text, out_path)
89
+ return out_path
90
 
91
  def translate_with_voice(audio, lang_pair, ref_voice):
92
+ # Handle Gradio sending numpy array + sample_rate
93
+ if isinstance(audio, tuple):
94
+ audio_path = audio[0] # (filepath, sample_rate) or (sample_rate, array)
95
+ else:
96
+ audio_path = audio
97
+
98
  # 1) Speech to text
99
+ text_input = asr(audio_path)["text"]
100
 
101
  # 2) Translation
102
+ load_translation_model(lang_pair)
 
 
 
 
 
103
  inputs = tokenizer(text_input, return_tensors="pt", padding=True)
104
+ translated_ids = model.generate(**inputs)
105
+ translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
106
 
107
  # 3) Text to speech
108
+ out_wav_path = text_to_speech(translated_text, ref_voice)
109
+ return translated_text, out_wav_path
110
 
111
  # --------------------------
112
  # Gradio UI
113
  # --------------------------
114
+ title = "🗣 Voice-Cloned Translator (English ↔ Spanish)"
115
+ description = """
116
+ Upload a short **reference voice** (5–10s, clean speech works best) and speak into the microphone.
117
+ This Space uses **IndexTTS** for zero-shot voice cloning and **Hugging Face models** for translation.
118
+ """
119
+
120
  with gr.Blocks() as demo:
121
+ gr.Markdown(f"# {title}\n{description}")
122
 
123
  with gr.Row():
124
  with gr.Column():
 
129
 
130
  with gr.Column():
131
  text_output = gr.Textbox(label="Translated Text")
132
+ audio_output = gr.Audio(label="🔊 Translated Audio", type="filepath")
133
 
134
  btn.click(
135
  fn=translate_with_voice,
 
137
  outputs=[text_output, audio_output]
138
  )
139
 
140
+ # Preload TTS on startup
141
+ def _startup():
142
+ try:
143
+ get_tts()
144
+ except Exception as e:
145
+ print("Warmup failed:", e)
146
+
147
+ if __name__ == "__main__":
148
+ _startup()
149
+ demo.launch(server_name="0.0.0.0", server_port=7860)
150
+