# Space3 / app.py — trilingual (English / Sinhala / Tamil) VITS TTS Gradio app.
import os
import tempfile

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from TTS.api import TTS
# --- ROMANIZER IMPORT ---
# The optional local `romanizer` module transliterates Sinhala Unicode text
# into Latin characters before synthesis (the Sinhala model appears to be
# trained on romanized input — see tts_sinhala below).
try:
    from romanizer import sinhala_to_roman
except ImportError:
    # Fallback: identity function, so the app still runs without the module
    # (the Sinhala model then receives raw Unicode text unchanged).
    def sinhala_to_roman(text): return text
# --- CONSOLIDATED MODEL LOADING ---
def load_standard_model(repo_id):
    """Build a CPU-only TTS engine from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository id that hosts ``best_model.pth`` and
            ``config.json``.

    Returns:
        A ``TTS`` instance ready for synthesis on CPU.
    """
    weights_file = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_file = hf_hub_download(repo_id=repo_id, filename="config.json")
    return TTS(model_path=weights_file, config_path=config_file, gpu=False)
def load_eng_model_with_surgery(src_vocab_rows=137, dst_vocab_rows=131):
    """Load the English VITS model, repairing a known embedding-size mismatch.

    The published checkpoint's text-encoder embedding table has more rows
    than the symbol set in ``config.json`` expects. When the known mismatch
    is detected, the surplus rows are sliced off, the patched checkpoint is
    saved locally, and the TTS engine is built from the patched file.

    Args:
        src_vocab_rows: Embedding row count that triggers the repair
            (default matches the published checkpoint).
        dst_vocab_rows: Row count to truncate the embedding down to
            (default matches the config's symbol table).

    Returns:
        A CPU ``TTS`` instance, using patched weights if surgery was needed,
        otherwise the original checkpoint as downloaded.
    """
    repo_id = "E-motionAssistant/text-to-speech-VITS-english"
    model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    checkpoint = torch.load(model_path, map_location="cpu")
    raw_weights = checkpoint['model']['text_encoder.emb.weight']
    if raw_weights.shape[0] == src_vocab_rows:
        # Drop the extra rows so the tensor matches the declared symbol count.
        checkpoint['model']['text_encoder.emb.weight'] = raw_weights[:dst_vocab_rows, :]
        fixed_model_path = "fixed_eng_model.pth"
        torch.save(checkpoint, fixed_model_path)
        return TTS(model_path=fixed_model_path, config_path=config_path, gpu=False)
    return TTS(model_path=model_path, config_path=config_path, gpu=False)
# --- INITIALIZATION ---
# Eagerly load every engine at startup so the first request is not slow.
print("Loading all models... this may take a moment.")
models = {}
models["sinhala"] = load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala")
models["tamil"] = load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil")
models["english"] = load_eng_model_with_surgery()
# --- SPECIFIC ENDPOINT FUNCTIONS ---
def tts_english(text):
    """Synthesize English speech.

    Args:
        text: English text to synthesize.

    Returns:
        Filesystem path to the rendered WAV file.
    """
    # Unique output path per call: the previous fixed "english_out.wav" name
    # was clobbered when concurrent Gradio requests ran simultaneously.
    fd, output = tempfile.mkstemp(prefix="english_", suffix=".wav")
    os.close(fd)  # tts_to_file reopens the path itself; keep only the name
    models["english"].tts_to_file(text=text, file_path=output)
    return output
def tts_sinhala(text):
    """Synthesize Sinhala speech from Unicode input.

    The text is first romanized (via ``sinhala_to_roman``) because the
    Sinhala model consumes Latin-script input; if the romanizer module is
    missing the text passes through unchanged.

    Args:
        text: Sinhala Unicode text to synthesize.

    Returns:
        Filesystem path to the rendered WAV file.
    """
    processed = sinhala_to_roman(text)
    # Unique output path per call: the previous fixed "sinhala_out.wav" name
    # was clobbered when concurrent Gradio requests ran simultaneously.
    fd, output = tempfile.mkstemp(prefix="sinhala_", suffix=".wav")
    os.close(fd)  # tts_to_file reopens the path itself; keep only the name
    models["sinhala"].tts_to_file(text=processed, file_path=output)
    return output
def tts_tamil(text):
    """Synthesize Tamil speech.

    Args:
        text: Tamil text to synthesize.

    Returns:
        Filesystem path to the rendered WAV file.
    """
    # Unique output path per call: the previous fixed "tamil_out.wav" name
    # was clobbered when concurrent Gradio requests ran simultaneously.
    fd, output = tempfile.mkstemp(prefix="tamil_", suffix=".wav")
    os.close(fd)  # tts_to_file reopens the path itself; keep only the name
    models["tamil"].tts_to_file(text=text, file_path=output)
    return output
# --- GRADIO UI WITH TABS ---
# One tab per language; each button is wired to its language-specific
# endpoint function, with api_name exposing a dedicated API route.
with gr.Blocks(title="Multilingual TTS API") as demo:
    gr.Markdown("# Trilingual TTS System")
    gr.Markdown("Choose a tab below to use a specific language endpoint.")

    with gr.Tab("English"):
        eng_text = gr.Textbox(label="English Text")
        eng_audio = gr.Audio(label="English Audio", type="filepath")
        eng_button = gr.Button("Synthesize English")
        # api_name creates a specific endpoint: /api/predict/english_tts
        eng_button.click(tts_english, inputs=eng_text, outputs=eng_audio, api_name="english_tts")

    with gr.Tab("Sinhala"):
        sin_text = gr.Textbox(label="Sinhala Text (Input Unicode)")
        sin_audio = gr.Audio(label="Sinhala Audio", type="filepath")
        sin_button = gr.Button("Synthesize Sinhala")
        sin_button.click(tts_sinhala, inputs=sin_text, outputs=sin_audio, api_name="sinhala_tts")

    with gr.Tab("Tamil"):
        tam_text = gr.Textbox(label="Tamil Text")
        tam_audio = gr.Audio(label="Tamil Audio", type="filepath")
        tam_button = gr.Button("Synthesize Tamil")
        tam_button.click(tts_tamil, inputs=tam_text, outputs=tam_audio, api_name="tamil_tts")

if __name__ == "__main__":
    demo.launch()