File size: 3,309 Bytes
5201795
48ee974
 
9f334f8
 
7ee5ee5
267f974
48ee974
 
 
 
 
267f974
48ee974
af0c43b
 
16efb96
 
48ee974
 
 
 
 
 
 
 
 
 
 
 
af0c43b
267f974
 
48ee974
267f974
 
 
48ee974
 
267f974
 
 
 
 
12dfa83
267f974
 
 
 
 
48ee974
267f974
 
 
 
a44a8cb
267f974
 
 
 
 
 
 
 
 
 
 
5201795
267f974
 
 
 
 
48ee974
267f974
 
 
 
 
5201795
48ee974
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import tempfile

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from TTS.api import TTS

# --- ROMANIZER IMPORT ---
# Prefer the project's romanizer; if it is not installed, degrade to an
# identity transform so the Sinhala endpoint still works on raw text.
try:
    from romanizer import sinhala_to_roman
except ImportError:
    def sinhala_to_roman(text):
        return text

# --- CONSOLIDATED MODEL LOADING ---
def load_standard_model(repo_id):
    """Fetch a VITS checkpoint and its config from the Hub and return a CPU TTS.

    Args:
        repo_id: Hugging Face repository id holding ``best_model.pth`` and
            ``config.json``.

    Returns:
        A ``TTS`` instance loaded from the downloaded files (``gpu=False``).
    """
    weights = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config = hf_hub_download(repo_id=repo_id, filename="config.json")
    return TTS(model_path=weights, config_path=config, gpu=False)

def load_eng_model_with_surgery():
    """Load the English VITS model, patching an oversized embedding if found.

    If the checkpoint's ``text_encoder.emb.weight`` has 137 rows, the extra
    rows are trimmed to 131 and the patched checkpoint is saved locally
    before loading.  Presumably 131 matches the symbol table in the repo's
    config — TODO confirm against ``config.json``.

    Returns:
        A ``TTS`` instance (CPU) built from either the original or the
        patched checkpoint.
    """
    repo_id = "E-motionAssistant/text-to-speech-VITS-english"
    weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    state = torch.load(weights_path, map_location="cpu")
    emb = state['model']['text_encoder.emb.weight']
    if emb.shape[0] != 137:
        # No mismatch detected — load the checkpoint untouched.
        return TTS(model_path=weights_path, config_path=config_path, gpu=False)

    # Keep only the first 131 symbol embeddings, persist, then load the patch.
    state['model']['text_encoder.emb.weight'] = emb[:131, :]
    patched_path = "fixed_eng_model.pth"
    torch.save(state, patched_path)
    return TTS(model_path=patched_path, config_path=config_path, gpu=False)

# --- INITIALIZATION ---
# Eagerly load every model at import time so the first request is fast.
print("Loading all models... this may take a moment.")
models = {}
models["sinhala"] = load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala")
models["tamil"] = load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil")
models["english"] = load_eng_model_with_surgery()

# --- SPECIFIC ENDPOINT FUNCTIONS ---
def tts_english(text):
    """Synthesize English speech from *text*.

    Returns:
        Path to a WAV file containing the synthesized audio.
    """
    # Use a unique temp file per call: the previous fixed filename
    # ("english_out.wav") was clobbered when two requests ran concurrently.
    fd, output = tempfile.mkstemp(suffix=".wav", prefix="english_")
    os.close(fd)  # tts_to_file reopens the path itself
    models["english"].tts_to_file(text=text, file_path=output)
    return output

def tts_sinhala(text):
    """Synthesize Sinhala speech from Unicode *text*.

    The input is first romanized via ``sinhala_to_roman`` — presumably the
    model expects romanized input; confirm against the training pipeline.

    Returns:
        Path to a WAV file containing the synthesized audio.
    """
    processed = sinhala_to_roman(text)
    # Unique temp file per call: the previous fixed filename
    # ("sinhala_out.wav") was clobbered under concurrent requests.
    fd, output = tempfile.mkstemp(suffix=".wav", prefix="sinhala_")
    os.close(fd)  # tts_to_file reopens the path itself
    models["sinhala"].tts_to_file(text=processed, file_path=output)
    return output

def tts_tamil(text):
    """Synthesize Tamil speech from *text*.

    Returns:
        Path to a WAV file containing the synthesized audio.
    """
    # Unique temp file per call: the previous fixed filename
    # ("tamil_out.wav") was clobbered under concurrent requests.
    fd, output = tempfile.mkstemp(suffix=".wav", prefix="tamil_")
    os.close(fd)  # tts_to_file reopens the path itself
    models["tamil"].tts_to_file(text=text, file_path=output)
    return output

# --- GRADIO UI WITH TABS ---
with gr.Blocks(title="Multilingual TTS API") as demo:
    gr.Markdown("# Trilingual TTS System")
    gr.Markdown("Choose a tab below to use a specific language endpoint.")

    # One tab per language. Each `api_name` exposes the handler as its own
    # endpoint (e.g. /api/predict/english_tts).
    _tabs = [
        ("English", "English Text", "English Audio",
         "Synthesize English", tts_english, "english_tts"),
        ("Sinhala", "Sinhala Text (Input Unicode)", "Sinhala Audio",
         "Synthesize Sinhala", tts_sinhala, "sinhala_tts"),
        ("Tamil", "Tamil Text", "Tamil Audio",
         "Synthesize Tamil", tts_tamil, "tamil_tts"),
    ]
    for tab_title, in_label, out_label, btn_label, handler, endpoint in _tabs:
        with gr.Tab(tab_title):
            text_in = gr.Textbox(label=in_label)
            audio_out = gr.Audio(label=out_label, type="filepath")
            synth_btn = gr.Button(btn_label)
            synth_btn.click(handler, inputs=text_in,
                            outputs=audio_out, api_name=endpoint)

if __name__ == "__main__":
    demo.launch()