Spaces:
Build error
Build error
# app.py
import os
import uuid
from pathlib import Path

import gradio as gr
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import FileResponse, HTMLResponse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# ------------------------
# Setup paths
# ------------------------
# Folder holding the fine-tuned checkpoint: config.json, vocab.json, model.pth.
MODEL_DIR = "my_model"

# Folder that receives audio files; created at import time when absent.
OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ------------------------
# Load TTS model
# ------------------------
# Locations of the files read from the checkpoint folder.
_config_path = os.path.join(MODEL_DIR, "config.json")
_vocab_path = os.path.join(MODEL_DIR, "vocab.json")

# Read the model configuration from disk.
config = XttsConfig()
config.load_json(_config_path)

# Build the model from that configuration, load the checkpoint found in
# MODEL_DIR, and move the model to the selected device.
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_dir=MODEL_DIR,
    vocab_path=_vocab_path,
    use_deepspeed=False,
)
model.to(device)
# ------------------------
# TTS function
# ------------------------
def tts_arabic(text: str, audio_file: str) -> str:
    """Synthesize Arabic speech for *text* in the voice of a reference clip.

    Args:
        text: Text to speak; passed to the model with language "ar".
        audio_file: Path to the reference speaker audio from which the
            conditioning latents are computed.

    Returns:
        Path of the WAV file written inside OUTPUT_DIR.
    """
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=[audio_file]
    )
    out = model.inference(
        text=text,
        language="ar",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=model.config.temperature,
        top_k=model.config.top_k,
        length_penalty=model.config.length_penalty,
        repetition_penalty=model.config.repetition_penalty,
        top_p=model.config.top_p,
    )

    # One file per call: a single shared "output.wav" would be overwritten by
    # a concurrent request while an earlier result is still being sent.
    # NOTE: files accumulate in OUTPUT_DIR; nothing here removes them.
    output_wav = os.path.join(OUTPUT_DIR, f"output_{uuid.uuid4().hex}.wav")

    # unsqueeze(0) adds a leading dimension before saving.
    # NOTE(review): 24000 is presumably the model's output sample rate —
    # confirm against the checkpoint's audio config.
    torchaudio.save(output_wav, torch.tensor(out["wav"]).unsqueeze(0), 24000)
    return output_wav
# ------------------------
# FastAPI setup
# ------------------------
app = FastAPI(title="EGTTS TTS API")


# Register the landing page. Without a route decorator the function is never
# reachable, and without response_class=HTMLResponse the markup would not be
# served as HTML.
# NOTE(review): the path "/" is assumed — confirm it is the intended route.
@app.get("/", response_class=HTMLResponse)
def index():
    """Return simple HTML that links to the Swagger docs and the Gradio UI."""
    return """
    <h2>Welcome to EGTTS TTS API</h2>
    <p>Swagger docs available at <a href="/docs">/docs</a></p>
    <p>Try the Gradio interface at <a href="/gradio">/gradio</a></p>
    """
# NOTE(review): the path "/tts" is assumed — confirm it is the intended route.
@app.post("/tts")
async def tts_endpoint(
    text: str = Form(...),
    audio_file: UploadFile = File(...),
):
    """Generate speech for *text* using an uploaded reference voice clip.

    Args:
        text: Form field with the text to synthesize.
        audio_file: Uploaded reference speaker audio.

    Returns:
        A FileResponse serving the generated WAV as "output.wav".
    """
    # The client-supplied filename is untrusted: joined into a path as-is it
    # could escape OUTPUT_DIR (e.g. "../../x") and it may be missing entirely.
    # Keep only a sanitized extension and generate the file name ourselves.
    suffix = os.path.splitext(os.path.basename(audio_file.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ".wav"
    file_path = os.path.join(OUTPUT_DIR, f"upload_{uuid.uuid4().hex}{suffix}")

    try:
        # Save uploaded file
        with open(file_path, "wb") as f:
            f.write(await audio_file.read())
        output_wav = tts_arabic(text, file_path)
    finally:
        # The reference clip is only needed during synthesis; remove it so
        # uploads do not pile up on disk, even when synthesis fails.
        if os.path.exists(file_path):
            os.remove(file_path)

    return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")
# ------------------------
# Gradio interface
# ------------------------
def gradio_fn(text, audio_file):
    """Adapt the Gradio inputs to tts_arabic.

    Args:
        text: Contents of the text box.
        audio_file: Value of the file input: either a path string or an
            object exposing the path as ``.name``; None when nothing has
            been uploaded.

    Returns:
        Path of the generated WAV file, or None when the text is blank or no
        reference audio has been provided.
    """
    # The interface is live, so this can run while an input is still empty;
    # skip synthesis instead of failing on a missing value.
    if audio_file is None or not text or not text.strip():
        return None

    # Accept both a plain path string and a file-like object with ``.name``.
    audio_path = audio_file if isinstance(audio_file, str) else audio_file.name
    return tts_arabic(text, audio_path)
# Input and output widgets for the demo page.
_text_input = gr.Textbox(label="Arabic Text", placeholder="اكتب النص هنا...")
_speaker_input = gr.File(label="Speaker Audio (.wav)")
_speech_output = gr.Audio(label="Generated Speech")

# Interface wiring gradio_fn to the widgets above; live=True is kept from the
# original configuration.
gradio_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[_text_input, _speaker_input],
    outputs=_speech_output,
    live=True,
    title="EGTTS Arabic TTS",
    description="Generate Arabic speech from text using your fine-tuned EGTTS model.",
)
# Mount Gradio inside FastAPI
def gradio_ui():
    """Mount the Gradio interface on the FastAPI app under /gradio.

    Calling ``launch()`` would start a separate Gradio server rather than
    mount the UI, and its return value has no ``.read()``. Mounting the
    interface as a sub-application serves it from this app at the path the
    index page links to.

    Returns:
        The FastAPI app with the Gradio interface mounted.
    """
    return gr.mount_gradio_app(app, gradio_interface, path="/gradio")


# Perform the mount at import time so the UI is available however the app is
# served (``python app.py`` or an external ASGI server importing ``app``).
app = gradio_ui()
# ------------------------
# Run server
# ------------------------
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    _server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_server_options)