import os
import tempfile

import gradio as gr
import torch
from TTS.api import TTS
from huggingface_hub import hf_hub_download

# --- ROMANIZER IMPORT ---
try:
    from romanizer import sinhala_to_roman
except ImportError:
    # Fallback: pass the text through unchanged when the romanizer
    # module is not installed.
    def sinhala_to_roman(text):
        return text


# --- CONSOLIDATED MODEL LOADING ---
def load_standard_model(repo_id):
    """Download a VITS checkpoint and config from the Hub and return a CPU TTS engine.

    Args:
        repo_id: Hugging Face Hub repository id holding ``best_model.pth``
            and ``config.json``.

    Returns:
        A ``TTS`` instance running on CPU.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    return TTS(model_path=model_path, config_path=config_path, gpu=False)


# Embedding-table row counts involved in the checkpoint surgery below.
# NOTE(review): 137/131 look like vocabulary sizes (checkpoint vs. what the
# config expects) — confirm against the repo's config.json.
_ENG_EMB_ROWS_CHECKPOINT = 137
_ENG_EMB_ROWS_EXPECTED = 131


def load_eng_model_with_surgery():
    """Load the English VITS model, truncating an oversized text-embedding table.

    The published checkpoint may carry an embedding matrix with more rows than
    the config expects; when detected, the extra rows are sliced off, the
    patched checkpoint is saved locally, and the patched file is loaded.

    Returns:
        A ``TTS`` instance running on CPU.
    """
    repo_id = "E-motionAssistant/text-to-speech-VITS-english"
    model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    checkpoint = torch.load(model_path, map_location="cpu")
    raw_weights = checkpoint['model']['text_encoder.emb.weight']
    if raw_weights.shape[0] == _ENG_EMB_ROWS_CHECKPOINT:
        # Keep only the leading rows the config expects, then load the
        # patched copy instead of the original download.
        checkpoint['model']['text_encoder.emb.weight'] = raw_weights[:_ENG_EMB_ROWS_EXPECTED, :]
        model_path = "fixed_eng_model.pth"
        torch.save(checkpoint, model_path)

    # Single exit point: loads either the original or the patched checkpoint
    # (the original code duplicated this return inside and after the `if`).
    return TTS(model_path=model_path, config_path=config_path, gpu=False)


# --- INITIALIZATION ---
print("Loading all models... this may take a moment.")
models = {
    "sinhala": load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala"),
    "tamil": load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil"),
    "english": load_eng_model_with_surgery(),
}


# --- SPECIFIC ENDPOINT FUNCTIONS ---
def _synthesize(lang, text):
    """Run the model for *lang* on *text* and return the path of a fresh wav file.

    Each call writes to a unique temporary file: the original fixed filenames
    ("english_out.wav", ...) let concurrent Gradio requests clobber each
    other's output.
    """
    fd, output = tempfile.mkstemp(prefix=f"{lang}_out_", suffix=".wav")
    os.close(fd)  # tts_to_file reopens the path itself; keep only the name.
    models[lang].tts_to_file(text=text, file_path=output)
    return output


def tts_english(text):
    """Synthesize English speech; returns a wav filepath."""
    return _synthesize("english", text)


def tts_sinhala(text):
    """Synthesize Sinhala speech from Unicode input (romanized first); returns a wav filepath."""
    return _synthesize("sinhala", sinhala_to_roman(text))


def tts_tamil(text):
    """Synthesize Tamil speech; returns a wav filepath."""
    return _synthesize("tamil", text)


# --- GRADIO UI WITH TABS ---
with gr.Blocks(title="Multilingual TTS API") as demo:
    gr.Markdown("# Trilingual TTS System")
    gr.Markdown("Choose a tab below to use a specific language endpoint.")

    with gr.Tab("English"):
        input_eng = gr.Textbox(label="English Text")
        output_eng = gr.Audio(label="English Audio", type="filepath")
        btn_eng = gr.Button("Synthesize English")
        # api_name creates a specific endpoint: /api/predict/english_tts
        btn_eng.click(tts_english, inputs=input_eng, outputs=output_eng, api_name="english_tts")

    with gr.Tab("Sinhala"):
        input_sin = gr.Textbox(label="Sinhala Text (Input Unicode)")
        output_sin = gr.Audio(label="Sinhala Audio", type="filepath")
        btn_sin = gr.Button("Synthesize Sinhala")
        btn_sin.click(tts_sinhala, inputs=input_sin, outputs=output_sin, api_name="sinhala_tts")

    with gr.Tab("Tamil"):
        input_tam = gr.Textbox(label="Tamil Text")
        output_tam = gr.Audio(label="Tamil Audio", type="filepath")
        btn_tam = gr.Button("Synthesize Tamil")
        btn_tam.click(tts_tamil, inputs=input_tam, outputs=output_tam, api_name="tamil_tts")

if __name__ == "__main__":
    demo.launch()