# app.py
"""FastAPI + Gradio front end for a fine-tuned XTTS Arabic text-to-speech model.

The module loads the model once at import time and exposes:

* ``GET  /``        - a tiny landing page,
* ``POST /tts/``    - multipart endpoint (text + reference speaker audio) that
  returns the synthesized speech as a WAV file,
* ``/gradio``       - an interactive Gradio UI mounted inside the same app.
"""
import os
import threading
import uuid
from pathlib import Path

import gradio as gr
import torch
import torchaudio
import uvicorn
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# ------------------------
# Setup paths
# ------------------------
MODEL_DIR = "my_model"  # folder with config.json, vocab.json, model.pth
OUTPUT_DIR = "outputs"
# Sample rate the generated waveform is written with.
# NOTE(review): assumed to match the model's output rate - confirm against config.json.
OUTPUT_SAMPLE_RATE = 24000
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------
# Load TTS model
# ------------------------
config = XttsConfig()
config.load_json(os.path.join(MODEL_DIR, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_dir=MODEL_DIR,
    use_deepspeed=False,
    vocab_path=os.path.join(MODEL_DIR, "vocab.json"),
)
model.to(device)

# A single model instance is shared by the REST endpoint and the Gradio UI,
# both of which may call it from worker threads. Inference is serialized
# because the model is not assumed to be safe for concurrent use.
_inference_lock = threading.Lock()


def _remove_quietly(path: str) -> None:
    """Delete *path*, ignoring the case where it does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Best-effort cleanup: the file was never created or is already gone.
        pass


# ------------------------
# TTS function
# ------------------------
def tts_arabic(text: str, audio_file: str) -> str:
    """Synthesize Arabic speech for *text* in the voice of *audio_file*.

    Args:
        text: The text to speak.
        audio_file: Path to a reference recording of the target speaker.

    Returns:
        Path of a newly written WAV file inside ``OUTPUT_DIR``. Each call gets
        its own uniquely named file so concurrent callers cannot overwrite
        one another's result. The caller owns the file and is responsible
        for deleting it when done.
    """
    with _inference_lock:
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
            audio_path=[audio_file]
        )
        out = model.inference(
            text=text,
            language="ar",
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=model.config.temperature,
            top_k=model.config.top_k,
            length_penalty=model.config.length_penalty,
            repetition_penalty=model.config.repetition_penalty,
            top_p=model.config.top_p,
        )

    output_wav = os.path.join(OUTPUT_DIR, f"output_{uuid.uuid4().hex}.wav")
    # torchaudio.save expects (channels, samples); add the channel dimension.
    waveform = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(output_wav, waveform, OUTPUT_SAMPLE_RATE)
    return output_wav


# ------------------------
# FastAPI setup
# ------------------------
app = FastAPI(title="EGTTS TTS API")


@app.get("/", response_class=HTMLResponse)
def index():
    """Return a minimal landing page pointing at the Swagger docs and Gradio UI."""
    return """
Swagger docs available at /docs
Try the Gradio interface at /gradio
"""


@app.post("/tts/")
async def tts_endpoint(
    text: str = Form(...),
    audio_file: UploadFile = File(...),
):
    """Synthesize speech from form *text* using the uploaded speaker sample.

    Args:
        text: Text to synthesize (multipart form field).
        audio_file: Reference speaker recording (multipart file field).

    Returns:
        A ``FileResponse`` streaming the generated WAV as ``output.wav``.

    Raises:
        HTTPException: 400 when *text* is empty or whitespace only.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")

    # Never use the client-supplied filename as a path: it may be missing or
    # contain directory components ("../../etc/..."). Only its extension is
    # kept, attached to a server-generated unique name.
    suffix = Path(audio_file.filename or "").suffix or ".wav"
    upload_path = os.path.join(OUTPUT_DIR, f"upload_{uuid.uuid4().hex}{suffix}")

    try:
        with open(upload_path, "wb") as f:
            f.write(await audio_file.read())
        # Inference is blocking and slow; run it off the event loop so other
        # requests keep being served.
        output_wav = await run_in_threadpool(tts_arabic, text, upload_path)
    finally:
        # The reference clip is only needed during synthesis.
        _remove_quietly(upload_path)

    # Delete the generated file once the response has been fully sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_quietly, output_wav)
    return FileResponse(
        output_wav,
        media_type="audio/wav",
        filename="output.wav",
        background=cleanup,
    )


# ------------------------
# Gradio interface
# ------------------------
def gradio_fn(text, audio_file):
    """Gradio callback: synthesize *text* with the uploaded speaker file.

    Returns ``None`` (no audio) while either input is still missing, which
    happens routinely because the interface runs in live mode.
    """
    if audio_file is None or not text or not text.strip():
        return None
    # Depending on the Gradio version the File component yields either a
    # path string or a temp-file wrapper exposing ``.name``; accept both.
    speaker_path = audio_file if isinstance(audio_file, str) else audio_file.name
    # NOTE: files produced here stay in OUTPUT_DIR; they cannot be removed in
    # this callback because Gradio reads the path after the function returns.
    return tts_arabic(text, speaker_path)


gradio_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Arabic Text", placeholder="اكتب النص هنا..."),
        gr.File(label="Speaker Audio (.wav)"),
    ],
    outputs=gr.Audio(label="Generated Speech"),
    live=True,
    title="EGTTS Arabic TTS",
    description="Generate Arabic speech from text using your fine-tuned EGTTS model.",
)

# Mount Gradio inside FastAPI. Mounting serves the UI (and its assets and
# websocket/queue endpoints) under /gradio from this same ASGI app, instead of
# launching a separate Gradio server on each request.
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")

# ------------------------
# Run server
# ------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)