File size: 4,713 Bytes
863b347
 
 
25f2399
057f29d
1f7b49e
057f29d
863b347
1f7b49e
057f29d
863b347
057f29d
863b347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
057f29d
1f7b49e
057f29d
1f7b49e
863b347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
057f29d
 
 
 
 
 
 
 
 
863b347
 
 
 
 
 
 
 
 
 
 
 
057f29d
863b347
057f29d
 
 
863b347
057f29d
863b347
 
 
 
 
 
 
 
 
 
057f29d
 
863b347
 
 
 
 
 
057f29d
863b347
057f29d
 
863b347
057f29d
863b347
 
057f29d
 
863b347
 
057f29d
 
b0d995c
057f29d
863b347
 
 
 
 
 
25f2399
863b347
d6458c9
25f2399
 
057f29d
 
 
 
1f7b49e
057f29d
 
863b347
1f7b49e
057f29d
 
 
 
 
027eeb4
863b347
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import sys
import tempfile
import gradio as gr
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
from transformers import MarianMTModel, MarianTokenizer, pipeline

# --------------------------
# Download IndexTTS repo from Hugging Face
# --------------------------
# All model assets are materialized into ./checkpoints at import time so the
# rest of the app can treat them as local files.
CHECKPOINTS_DIR = os.path.abspath("checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

# NOTE(review): `local_dir_use_symlinks` is deprecated in recent
# huggingface_hub releases (real files are now the default) — confirm the
# pinned hub version before removing it.
repo_path = snapshot_download(
    repo_id="mlx-community/IndexTTS",  # Correct repo
    local_dir=CHECKPOINTS_DIR,
    local_dir_use_symlinks=False,
    # Only pull the files IndexTTS actually needs, not the whole repo.
    allow_patterns=[
        "config.yaml",
        "bpe.model",
        "unigram_12000.vocab",
        "gpt.pth",
        "bigvgan_generator.pth",
        "bigvgan_discriminator.pth",
        "dvae.pth",
    ],
)
# The downloaded snapshot also contains the `indextts` Python package, so it
# must be on sys.path before the import below can succeed.
sys.path.append(repo_path)

from indextts.infer import IndexTTS

# --------------------------
# Initialize TTS safely
# --------------------------
_tts = None
def get_tts():
    """Return the process-wide IndexTTS instance, constructing it on first use.

    Raises:
        gr.Error: if the downloaded checkpoint files are missing.
    """
    global _tts
    if _tts is not None:
        return _tts
    try:
        _tts = IndexTTS(model_dir=repo_path, cfg_path=os.path.join(repo_path, "config.yaml"))
    except FileNotFoundError as e:
        # Surface a user-visible error in the Gradio UI instead of a traceback.
        print("Error loading IndexTTS:", e)
        raise gr.Error("IndexTTS model files not found!")
    return _tts

# Limit CPU threads (important for Spaces)
torch.set_num_threads(1)
# NOTE(review): OMP/MKL env vars are usually read once at library load time;
# since torch is already imported above, these may have no effect — consider
# setting them before the imports, or via the Space's environment settings.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# --------------------------
# Translation models
# --------------------------
# Dropdown label -> MarianMT checkpoint on the Hugging Face Hub.
language_models = {
    "Spanish β†’ English": "Helsinki-NLP/opus-mt-es-en",
    "English β†’ Spanish": "Helsinki-NLP/opus-mt-en-es"
}

# Cache of the most recently loaded translation model, so switching direction
# in the UI only reloads when the pair actually changes.
current_model_name = None
tokenizer = None
model = None

def load_translation_model(lang_pair):
    """Ensure the MarianMT tokenizer/model for *lang_pair* are loaded.

    Reuses the cached pair when it already matches, so repeated calls with the
    same direction are free.
    """
    global current_model_name, tokenizer, model
    wanted = language_models[lang_pair]
    if wanted == current_model_name:
        return
    current_model_name = wanted
    tokenizer = MarianTokenizer.from_pretrained(wanted)
    model = MarianMTModel.from_pretrained(wanted)

# --------------------------
# Speech-to-text (ASR)
# --------------------------
# Whisper-small is downloaded and loaded eagerly at import time; the pipeline
# accepts a file path and returns a dict with a "text" key.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")

# --------------------------
# Core functions
# --------------------------
def text_to_speech(text, ref_voice_path):
    """Synthesize *text* in the voice of *ref_voice_path* using IndexTTS.

    Returns the path of a temporary WAV file holding the generated audio.
    The file is intentionally not deleted so Gradio can serve it afterwards.
    """
    engine = get_tts()
    # Reserve a unique .wav path; IndexTTS writes the audio into it.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    engine.infer(ref_voice_path, text, tmp.name)
    return tmp.name

def translate_with_voice(audio, lang_pair, ref_voice):
    """Full pipeline: recorded speech -> ASR -> translation -> cloned TTS.

    Args:
        audio: microphone input — either a filepath (Audio type="filepath")
            or, in numpy mode, a ``(sample_rate, np.ndarray)`` tuple.
        lang_pair: key into ``language_models`` selecting the direction.
        ref_voice: filepath of the reference voice clip for cloning.

    Returns:
        (translated_text, path_to_translated_wav)
    """
    # BUG FIX: Gradio's numpy audio format is (sample_rate, data); the old
    # code took audio[0] — the integer sample rate — and fed it to the ASR
    # pipeline as a "path". Persist the samples to a real WAV file instead.
    if isinstance(audio, tuple):
        sample_rate, samples = audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
        sf.write(audio_path, samples, sample_rate)
    else:
        audio_path = audio

    # 1) Speech to text
    text_input = asr(audio_path)["text"]

    # 2) Translation (loads/reuses the MarianMT model for this direction)
    load_translation_model(lang_pair)
    inputs = tokenizer(text_input, return_tensors="pt", padding=True)
    translated_ids = model.generate(**inputs)
    translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)

    # 3) Text to speech in the cloned voice
    out_wav_path = text_to_speech(translated_text, ref_voice)
    return translated_text, out_wav_path

# --------------------------
# Gradio UI
# --------------------------
title = "πŸ—£ Voice-Cloned Translator (English ↔ Spanish)"
description = """
Upload a short **reference voice** (5–10s, clean speech works best) and speak into the microphone.
This Space uses **IndexTTS** for zero-shot voice cloning and **Hugging Face models** for translation.
"""

# Two-column layout: inputs (mic, direction, reference voice) on the left,
# outputs (translated text + synthesized audio) on the right.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}\n{description}")

    with gr.Row():
        with gr.Column():
            # type="filepath" means handlers receive a path string, not arrays.
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="πŸŽ™ Speak")
            lang_dropdown = gr.Dropdown(list(language_models.keys()), label="🌍 Target Language", value="Spanish β†’ English")
            ref_voice_input = gr.Audio(sources=["upload"], type="filepath", label="🎧 Reference Voice (5–10s)")
            btn = gr.Button("Translate & Speak")

        with gr.Column():
            text_output = gr.Textbox(label="Translated Text")
            audio_output = gr.Audio(label="πŸ”Š Translated Audio", type="filepath")

    # Wire the button to the end-to-end pipeline.
    btn.click(
        fn=translate_with_voice,
        inputs=[audio_input, lang_dropdown, ref_voice_input],
        outputs=[text_output, audio_output]
    )

# Preload TTS on startup
def _startup():
    """Best-effort warmup: load IndexTTS eagerly so the first request is fast.

    Any failure is logged and swallowed — warmup must never block app launch.
    """
    try:
        get_tts()
    except Exception as exc:
        print("Warmup failed:", exc)

if __name__ == "__main__":
    _startup()
    # Bind on all interfaces on the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)