Spaces:
Running
Running
File size: 3,309 Bytes
5201795 48ee974 9f334f8 7ee5ee5 267f974 48ee974 267f974 48ee974 af0c43b 16efb96 48ee974 af0c43b 267f974 48ee974 267f974 48ee974 267f974 12dfa83 267f974 48ee974 267f974 a44a8cb 267f974 5201795 267f974 48ee974 267f974 5201795 48ee974 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import gradio as gr
import os
import tempfile

import torch
from TTS.api import TTS
from huggingface_hub import hf_hub_download
# --- ROMANIZER IMPORT ---
# Optional project-local transliteration helper. If the `romanizer` module is
# missing, fall back to an identity function so the app still starts; Sinhala
# text is then passed to the model unchanged (presumably as raw Unicode —
# synthesis quality in that case is unverified).
try:
    from romanizer import sinhala_to_roman
except ImportError:
    def sinhala_to_roman(text): return text
# --- CONSOLIDATED MODEL LOADING ---
def load_standard_model(repo_id, gpu=False):
    """Download a VITS checkpoint + config from a HF Hub repo and load it.

    Args:
        repo_id: Hugging Face Hub repository id holding ``best_model.pth``
            and ``config.json``.
        gpu: Whether to load the model on GPU. Defaults to False (CPU),
            matching the original hard-coded behavior.

    Returns:
        A ready-to-use ``TTS`` instance.
    """
    # hf_hub_download caches locally, so repeated startups reuse the files.
    model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    return TTS(model_path=model_path, config_path=config_path, gpu=gpu)
def load_eng_model_with_surgery():
    """Load the English VITS model, repairing an oversized embedding table.

    The published English checkpoint sometimes ships a text-embedding matrix
    with 137 rows where 131 are expected (presumably extra symbols left over
    from training — TODO confirm against the repo's config). When that exact
    mismatch is detected, the surplus rows are sliced off and the repaired
    checkpoint is saved locally before loading.

    Returns:
        A CPU-loaded ``TTS`` instance for the English model.
    """
    # Embedding row counts: what the broken checkpoint ships vs. what we keep.
    PUBLISHED_EMB_ROWS = 137
    EXPECTED_EMB_ROWS = 131

    repo_id = "E-motionAssistant/text-to-speech-VITS-english"
    model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    # NOTE(security): torch.load unpickles arbitrary objects; acceptable here
    # only because the checkpoint comes from a fixed, known HF repository.
    checkpoint = torch.load(model_path, map_location="cpu")
    emb = checkpoint['model']['text_encoder.emb.weight']
    if emb.shape[0] == PUBLISHED_EMB_ROWS:
        # Keep only the first EXPECTED_EMB_ROWS embedding rows, persist the
        # fixed checkpoint, and load from that file instead.
        checkpoint['model']['text_encoder.emb.weight'] = emb[:EXPECTED_EMB_ROWS, :]
        fixed_model_path = "fixed_eng_model.pth"
        torch.save(checkpoint, fixed_model_path)
        model_path = fixed_model_path

    # Single exit point (the original duplicated the TTS(...) construction).
    return TTS(model_path=model_path, config_path=config_path, gpu=False)
# --- INITIALIZATION ---
# Load all three models eagerly at import time and keep them in a
# module-level dict keyed by language. NOTE: this performs network I/O
# (model downloads) when the module is first imported.
print("Loading all models... this may take a moment.")
models = {
    "sinhala": load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala"),
    "tamil": load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil"),
    "english": load_eng_model_with_surgery()
}
# --- SPECIFIC ENDPOINT FUNCTIONS ---
def tts_english(text):
    """Synthesize English speech for *text* and return the output wav path.

    Writes to a unique temp file rather than a fixed filename so that
    concurrent Gradio requests cannot clobber each other's audio.
    """
    fd, output = tempfile.mkstemp(prefix="english_", suffix=".wav")
    os.close(fd)  # TTS reopens the path itself; release our handle.
    models["english"].tts_to_file(text=text, file_path=output)
    return output
def tts_sinhala(text):
    """Romanize Sinhala *text*, synthesize it, and return the wav path.

    The input is transliterated via ``sinhala_to_roman`` first (identity
    fallback when the romanizer module is unavailable). Writes to a unique
    temp file so concurrent Gradio requests cannot clobber each other.
    """
    processed = sinhala_to_roman(text)
    fd, output = tempfile.mkstemp(prefix="sinhala_", suffix=".wav")
    os.close(fd)  # TTS reopens the path itself; release our handle.
    models["sinhala"].tts_to_file(text=processed, file_path=output)
    return output
def tts_tamil(text):
    """Synthesize Tamil speech for *text* and return the output wav path.

    Writes to a unique temp file rather than a fixed filename so that
    concurrent Gradio requests cannot clobber each other's audio.
    """
    fd, output = tempfile.mkstemp(prefix="tamil_", suffix=".wav")
    os.close(fd)  # TTS reopens the path itself; release our handle.
    models["tamil"].tts_to_file(text=text, file_path=output)
    return output
# --- GRADIO UI WITH TABS ---
# One tab per language; each tab wires its textbox/button/audio trio to the
# matching tts_* function. The api_name on each click handler exposes a
# stable named API endpoint in addition to the UI.
with gr.Blocks(title="Multilingual TTS API") as demo:
    gr.Markdown("# Trilingual TTS System")
    gr.Markdown("Choose a tab below to use a specific language endpoint.")
    with gr.Tab("English"):
        input_eng = gr.Textbox(label="English Text")
        # type="filepath" means the handler returns a path to a wav on disk.
        output_eng = gr.Audio(label="English Audio", type="filepath")
        btn_eng = gr.Button("Synthesize English")
        # api_name creates a specific endpoint: /api/predict/english_tts
        btn_eng.click(tts_english, inputs=input_eng, outputs=output_eng, api_name="english_tts")
    with gr.Tab("Sinhala"):
        input_sin = gr.Textbox(label="Sinhala Text (Input Unicode)")
        output_sin = gr.Audio(label="Sinhala Audio", type="filepath")
        btn_sin = gr.Button("Synthesize Sinhala")
        btn_sin.click(tts_sinhala, inputs=input_sin, outputs=output_sin, api_name="sinhala_tts")
    with gr.Tab("Tamil"):
        input_tam = gr.Textbox(label="Tamil Text")
        output_tam = gr.Audio(label="Tamil Audio", type="filepath")
        btn_tam = gr.Button("Synthesize Tamil")
        btn_tam.click(tts_tamil, inputs=input_tam, outputs=output_tam, api_name="tamil_tts")
# Standard script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()