"""Voice-cloned speech translator (English <-> Spanish) served as a Gradio app.

Pipeline: Whisper speech recognition -> MarianMT translation -> IndexTTS
speech synthesis conditioned on a user-supplied reference voice.
"""
import os
import sys
import tempfile

# Limit CPU threads (important for Spaces). The OpenMP/MKL runtimes read these
# variables when they initialise, so set them before torch is imported.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import gradio as gr  # noqa: E402
import soundfile as sf  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import snapshot_download  # noqa: E402
from transformers import MarianMTModel, MarianTokenizer, pipeline  # noqa: E402

torch.set_num_threads(1)

# --------------------------
# Download IndexTTS repo from Hugging Face
# --------------------------
CHECKPOINTS_DIR = os.path.abspath("checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

repo_path = snapshot_download(
    repo_id="mlx-community/IndexTTS",  # Correct repo
    local_dir=CHECKPOINTS_DIR,
    local_dir_use_symlinks=False,
    allow_patterns=[
        "config.yaml",
        "bpe.model",
        "unigram_12000.vocab",
        "gpt.pth",
        "bigvgan_generator.pth",
        "bigvgan_discriminator.pth",
        "dvae.pth",
    ],
)

# NOTE(review): allow_patterns above fetches model files only, so the
# `indextts` package presumably has to be importable from elsewhere
# (e.g. installed in the environment) -- confirm.
sys.path.append(repo_path)
from indextts.infer import IndexTTS  # noqa: E402

# --------------------------
# Initialize TTS safely
# --------------------------
_tts = None


def get_tts():
    """Return the shared IndexTTS instance, creating it on first use.

    Raises:
        gr.Error: if the model files cannot be found under ``repo_path``.
    """
    global _tts
    if _tts is None:
        try:
            _tts = IndexTTS(
                model_dir=repo_path,
                cfg_path=os.path.join(repo_path, "config.yaml"),
            )
        except FileNotFoundError as e:
            print("Error loading IndexTTS:", e)
            raise gr.Error("IndexTTS model files not found!") from e
    return _tts


# --------------------------
# Translation models
# --------------------------
language_models = {
    "Spanish → English": "Helsinki-NLP/opus-mt-es-en",
    "English → Spanish": "Helsinki-NLP/opus-mt-en-es",
}

current_model_name = None
tokenizer = None
model = None


def load_translation_model(lang_pair):
    """Ensure the module-level ``tokenizer``/``model`` match ``lang_pair``.

    The model is reloaded only when the requested pair differs from the one
    currently loaded.

    Raises:
        gr.Error: if ``lang_pair`` is not a key of ``language_models``.
    """
    global current_model_name, tokenizer, model

    wanted_name = language_models.get(lang_pair)
    if wanted_name is None:
        raise gr.Error(f"Unsupported language pair: {lang_pair!r}")
    if wanted_name == current_model_name:
        return

    # Load into locals first: if either download fails, the globals keep
    # describing the previously loaded (consistent) model instead of
    # claiming the new one is active.
    new_tokenizer = MarianTokenizer.from_pretrained(wanted_name)
    new_model = MarianMTModel.from_pretrained(wanted_name)
    tokenizer, model, current_model_name = new_tokenizer, new_model, wanted_name


# --------------------------
# Speech-to-text (ASR)
# --------------------------
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")


# --------------------------
# Core functions
# --------------------------
def text_to_speech(text, ref_voice_path):
    """
    Convert text to speech using IndexTTS.
    Returns a temporary WAV file path.

    The file is created with ``delete=False`` so it outlives this call;
    it is removed here only when synthesis fails.
    """
    tts = get_tts()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    try:
        tts.infer(ref_voice_path, text, out_path)
    except Exception:
        # Do not leave an empty placeholder file behind on failure.
        if os.path.exists(out_path):
            os.remove(out_path)
        raise
    return out_path


def _audio_to_path(audio):
    """Return ``(path, is_temporary)`` for a Gradio audio value.

    Accepts a file path, or a ``(sample_rate, samples)`` tuple which is
    written to a temporary WAV file so the rest of the pipeline can read it.
    """
    if isinstance(audio, tuple):
        sample_rate, samples = audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        sf.write(wav_path, samples, sample_rate)
        return wav_path, True
    return audio, False


def translate_with_voice(audio, lang_pair, ref_voice):
    """Transcribe ``audio``, translate it, and speak it in the reference voice.

    Args:
        audio: recorded speech, as a file path or a ``(sample_rate, samples)``
            tuple.
        lang_pair: a key of ``language_models``.
        ref_voice: path to the reference voice clip passed to IndexTTS.

    Returns:
        ``(translated_text, out_wav_path)``.

    Raises:
        gr.Error: if an input is missing or no speech could be transcribed.
    """
    if audio is None:
        raise gr.Error("No audio received. Please record something first.")
    if ref_voice is None:
        raise gr.Error("Please upload a reference voice clip.")

    audio_path, is_temporary = _audio_to_path(audio)

    # 1) Speech to text
    try:
        text_input = asr(audio_path)["text"]
    finally:
        if is_temporary and os.path.exists(audio_path):
            os.remove(audio_path)

    if not text_input or not text_input.strip():
        raise gr.Error("No speech was recognised in the recording.")

    # 2) Translation
    load_translation_model(lang_pair)
    # truncation=True keeps over-long transcripts within the model's limit.
    inputs = tokenizer(
        text_input, return_tensors="pt", padding=True, truncation=True
    )
    translated_ids = model.generate(**inputs)
    translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)

    # 3) Text to speech
    out_wav_path = text_to_speech(translated_text, ref_voice)
    return translated_text, out_wav_path


# --------------------------
# Gradio UI
# --------------------------
title = "🗣 Voice-Cloned Translator (English ↔ Spanish)"
description = """
Upload a short **reference voice** (5–10s, clean speech works best) and speak into the microphone.
This Space uses **IndexTTS** for zero-shot voice cloning and **Hugging Face models** for translation.
"""

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}\n{description}")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone"], type="filepath", label="🎙 Speak"
            )
            lang_dropdown = gr.Dropdown(
                list(language_models.keys()),
                label="🌍 Target Language",
                value="Spanish → English",
            )
            ref_voice_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="🎧 Reference Voice (5–10s)",
            )
            btn = gr.Button("Translate & Speak")
        with gr.Column():
            text_output = gr.Textbox(label="Translated Text")
            audio_output = gr.Audio(label="🔊 Translated Audio", type="filepath")

    btn.click(
        fn=translate_with_voice,
        inputs=[audio_input, lang_dropdown, ref_voice_input],
        outputs=[text_output, audio_output],
    )


# Preload TTS on startup
def _startup():
    """Warm up the TTS model; failures are logged and do not stop the app."""
    try:
        get_tts()
    except Exception as e:
        print("Warmup failed:", e)


if __name__ == "__main__":
    _startup()
    demo.launch(server_name="0.0.0.0", server_port=7860)