Translator / app.py
drixo's picture
Update app.py
863b347 verified
import os
import sys
import tempfile
import gradio as gr
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
from transformers import MarianMTModel, MarianTokenizer, pipeline
# --------------------------
# Download IndexTTS repo from Hugging Face
# --------------------------
# Absolute path of the local folder that receives the downloaded files.
CHECKPOINTS_DIR = os.path.abspath("checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

# The only files requested from the Hub repository.
_INDEXTTS_FILES = [
    "config.yaml",
    "bpe.model",
    "unigram_12000.vocab",
    "gpt.pth",
    "bigvgan_generator.pth",
    "bigvgan_discriminator.pth",
    "dvae.pth",
]

repo_path = snapshot_download(
    repo_id="mlx-community/IndexTTS",
    local_dir=CHECKPOINTS_DIR,
    local_dir_use_symlinks=False,
    allow_patterns=_INDEXTTS_FILES,
)

# The download folder is put on the import path before IndexTTS is imported.
# NOTE(review): the patterns above list only config/weight files, so the
# `indextts` package itself presumably comes from the installed environment
# rather than from this download - confirm.
sys.path.append(repo_path)
from indextts.infer import IndexTTS
# --------------------------
# Initialize TTS safely
# --------------------------
# Lazily created IndexTTS instance; stays None until get_tts() first succeeds.
_tts = None


def get_tts():
    """Return the shared IndexTTS engine, creating it on first use.

    The engine is built from the files under ``repo_path`` and cached in the
    module-level ``_tts`` so later calls return the same object.

    Returns:
        The cached ``IndexTTS`` instance.

    Raises:
        gr.Error: If constructing IndexTTS raises ``FileNotFoundError``. The
            original exception is attached as the cause.
    """
    global _tts
    if _tts is not None:
        return _tts
    try:
        _tts = IndexTTS(
            model_dir=repo_path,
            cfg_path=os.path.join(repo_path, "config.yaml"),
        )
    except FileNotFoundError as e:
        print("Error loading IndexTTS:", e)
        # Chain explicitly so the traceback still names the missing file.
        raise gr.Error("IndexTTS model files not found!") from e
    return _tts
# Restrict CPU parallelism to a single thread (important for Spaces).
torch.set_num_threads(1)
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
# --------------------------
# Translation models
# --------------------------
# Maps the dropdown label of a translation direction to its MarianMT checkpoint.
language_models = {
    "Spanish β†’ English": "Helsinki-NLP/opus-mt-es-en",
    "English β†’ Spanish": "Helsinki-NLP/opus-mt-en-es",
}

# Name of the checkpoint currently loaded, plus its tokenizer and model.
# All three stay None until load_translation_model() first succeeds.
current_model_name = None
tokenizer = None
model = None


def load_translation_model(lang_pair):
    """Make sure the tokenizer and model for *lang_pair* are the loaded ones.

    Nothing is reloaded when the requested checkpoint is already active.

    Args:
        lang_pair: A key of ``language_models``.

    Raises:
        KeyError: If *lang_pair* is not a key of ``language_models``.
    """
    global current_model_name, tokenizer, model
    wanted = language_models[lang_pair]
    if wanted == current_model_name:
        return
    # Load into locals first. If either load raises, the three globals still
    # describe the previously loaded model instead of naming a checkpoint
    # whose tokenizer/model were never installed.
    new_tokenizer = MarianTokenizer.from_pretrained(wanted)
    new_model = MarianMTModel.from_pretrained(wanted)
    current_model_name, tokenizer, model = wanted, new_tokenizer, new_model
# --------------------------
# Speech-to-text (ASR)
# --------------------------
# Speech-recognition pipeline, built once at import time and called by
# translate_with_voice().
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
)
# --------------------------
# Core functions
# --------------------------
def text_to_speech(text, ref_voice_path):
    """Synthesize *text* with IndexTTS and return the path of the WAV file.

    Args:
        text: Text to speak.
        ref_voice_path: Path of the reference recording passed to IndexTTS.

    Returns:
        Path of a newly created temporary ``.wav`` file. The file is not
        deleted automatically; the caller owns it.

    Raises:
        Whatever ``get_tts()`` or ``tts.infer`` raises. If inference fails,
        the temporary file is removed before the exception propagates.
    """
    tts = get_tts()
    # Reserve a unique output path. delete=False keeps the file after the
    # handle is closed so it can be returned to the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    try:
        tts.infer(ref_voice_path, text, out_path)
    except Exception:
        # Nothing else will clean up a delete=False file, so do it here.
        try:
            os.remove(out_path)
        except OSError:
            pass
        raise
    return out_path
def translate_with_voice(audio, lang_pair, ref_voice):
    """Transcribe speech, translate it, and speak the translation.

    Args:
        audio: The recorded speech. Either a file path, a
            ``(filepath, sample_rate)`` tuple, or a ``(sample_rate, samples)``
            tuple.
        lang_pair: A key of ``language_models`` selecting the direction.
        ref_voice: Path of the reference voice recording used for synthesis.

    Returns:
        A ``(translated_text, wav_path)`` tuple.

    Raises:
        gr.Error: If the recording, the reference voice, or a valid language
            pair is missing, or if no text was recognised.
    """
    # Fail early with a readable message instead of an opaque error from the
    # models further down.
    if audio is None:
        raise gr.Error("Please record some speech first.")
    if ref_voice is None:
        raise gr.Error("Please upload a reference voice clip.")
    if lang_pair not in language_models:
        raise gr.Error("Please choose a translation direction.")

    converted_path = None  # temp WAV created here, if any
    try:
        if isinstance(audio, tuple):
            first, second = audio[0], audio[1]
            if isinstance(first, (str, os.PathLike)):
                # (filepath, sample_rate)
                audio_path = first
            else:
                # (sample_rate, samples): write the samples to a WAV file so
                # the recogniser is always given a path.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    converted_path = tmp.name
                sf.write(converted_path, second, first)
                audio_path = converted_path
        else:
            audio_path = audio

        # 1) Speech to text
        text_input = asr(audio_path)["text"]
    finally:
        if converted_path is not None:
            try:
                os.remove(converted_path)
            except OSError:
                pass

    if not text_input.strip():
        raise gr.Error("No speech was recognised in the recording.")

    # 2) Translation
    load_translation_model(lang_pair)
    # truncation=True keeps over-long transcripts within the model's limit.
    inputs = tokenizer(text_input, return_tensors="pt", padding=True, truncation=True)
    translated_ids = model.generate(**inputs)
    translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)

    # 3) Text to speech
    out_wav_path = text_to_speech(translated_text, ref_voice)
    return translated_text, out_wav_path
# --------------------------
# Gradio UI
# --------------------------
# Heading and introduction rendered at the top of the page.
# NOTE(review): the non-ASCII characters in these strings and in the component
# labels below look mis-encoded - confirm against the original file before
# changing them; the dropdown default must keep matching a key of
# `language_models`.
title = "πŸ—£ Voice-Cloned Translator (English ↔ Spanish)"
description = """
Upload a short **reference voice** (5–10s, clean speech works best) and speak into the microphone.
This Space uses **IndexTTS** for zero-shot voice cloning and **Hugging Face models** for translation.
"""

# NOTE(review): the Row/Column nesting below is reconstructed from the order of
# the components - confirm against the original layout.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}\n{description}")

    with gr.Row():
        # Inputs: microphone recording, direction, reference voice, trigger.
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="πŸŽ™ Speak",
            )
            lang_dropdown = gr.Dropdown(
                choices=list(language_models),
                label="🌍 Target Language",
                value="Spanish β†’ English",
            )
            ref_voice_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="🎧 Reference Voice (5–10s)",
            )
            btn = gr.Button("Translate & Speak")

        # Outputs: translated text and the synthesized audio.
        with gr.Column():
            text_output = gr.Textbox(label="Translated Text")
            audio_output = gr.Audio(label="πŸ”Š Translated Audio", type="filepath")

    # Clicking the button runs the full pipeline on the three inputs.
    btn.click(
        fn=translate_with_voice,
        inputs=[audio_input, lang_dropdown, ref_voice_input],
        outputs=[text_output, audio_output],
    )
# Preload TTS on startup
def _startup():
try:
get_tts()
except Exception as e:
print("Warmup failed:", e)
if __name__ == "__main__":
    # Preload the TTS engine (best-effort), then serve the UI on all
    # interfaces at port 7860.
    _startup()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )