MariaKaiser commited on
Commit
6335b4b
·
verified ·
1 Parent(s): a22e089

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from fastapi import FastAPI, UploadFile, File, Form
3
+ from fastapi.responses import FileResponse, HTMLResponse
4
+ import torch
5
+ import torchaudio
6
+ import os
7
+ from pathlib import Path
8
+ from TTS.tts.models.xtts import Xtts
9
+ from TTS.tts.configs.xtts_config import XttsConfig
10
+ import gradio as gr
11
+ import uvicorn
12
+
13
+ # ------------------------
14
+ # Setup paths
15
+ # ------------------------
16
+ MODEL_DIR = "my_model" # folder with config.json, vocab.json, model.pth
17
+ OUTPUT_DIR = "outputs"
18
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
19
+
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+
22
+ # ------------------------
23
+ # Load TTS model
24
+ # ------------------------
25
+ config = XttsConfig()
26
+ config.load_json(os.path.join(MODEL_DIR, "config.json"))
27
+
28
+ model = Xtts.init_from_config(config)
29
+ model.load_checkpoint(
30
+ config,
31
+ checkpoint_dir=MODEL_DIR,
32
+ use_deepspeed=False,
33
+ vocab_path=os.path.join(MODEL_DIR, "vocab.json")
34
+ )
35
+ model.to(device)
36
+
37
+ # ------------------------
38
+ # TTS function
39
+ # ------------------------
40
+ def tts_arabic(text: str, audio_file: str) -> str:
41
+ gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
42
+ out = model.inference(
43
+ text=text,
44
+ language="ar",
45
+ gpt_cond_latent=gpt_cond_latent,
46
+ speaker_embedding=speaker_embedding,
47
+ temperature=model.config.temperature,
48
+ top_k=model.config.top_k,
49
+ length_penalty=model.config.length_penalty,
50
+ repetition_penalty=model.config.repetition_penalty,
51
+ top_p=model.config.top_p,
52
+ )
53
+ output_wav = os.path.join(OUTPUT_DIR, "output.wav")
54
+ torchaudio.save(output_wav, torch.tensor(out["wav"]).unsqueeze(0), 24000)
55
+ return output_wav
56
+
57
+ # ------------------------
58
+ # FastAPI setup
59
+ # ------------------------
60
+ app = FastAPI(title="EGTTS TTS API")
61
+
62
+ @app.get("/", response_class=HTMLResponse)
63
+ def index():
64
+ """Return simple HTML that links to Gradio UI"""
65
+ return """
66
+ <h2>Welcome to EGTTS TTS API</h2>
67
+ <p>Swagger docs available at <a href="/docs">/docs</a></p>
68
+ <p>Try the Gradio interface at <a href="/gradio">/gradio</a></p>
69
+ """
70
+
71
+ @app.post("/tts/")
72
+ async def tts_endpoint(
73
+ text: str = Form(...),
74
+ audio_file: UploadFile = File(...)
75
+ ):
76
+ # Save uploaded file
77
+ file_path = os.path.join(OUTPUT_DIR, audio_file.filename)
78
+ with open(file_path, "wb") as f:
79
+ f.write(await audio_file.read())
80
+
81
+ output_wav = tts_arabic(text, file_path)
82
+ return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")
83
+
84
+ # ------------------------
85
+ # Gradio interface
86
+ # ------------------------
87
+ def gradio_fn(text, audio_file):
88
+ return tts_arabic(text, audio_file.name)
89
+
90
+ gradio_interface = gr.Interface(
91
+ fn=gradio_fn,
92
+ inputs=[
93
+ gr.Textbox(label="Arabic Text", placeholder="اكتب النص هنا..."),
94
+ gr.File(label="Speaker Audio (.wav)")
95
+ ],
96
+ outputs=gr.Audio(label="Generated Speech"),
97
+ live=True,
98
+ title="EGTTS Arabic TTS",
99
+ description="Generate Arabic speech from text using your fine-tuned EGTTS model."
100
+ )
101
+
102
+ # Mount Gradio inside FastAPI
103
+ @app.get("/gradio", response_class=HTMLResponse)
104
+ def gradio_ui():
105
+ return gradio_interface.launch(inline=True, share=False, prevent_thread_lock=True).read()
106
+
107
+ # ------------------------
108
+ # Run server
109
+ # ------------------------
110
+ if __name__ == "__main__":
111
+ uvicorn.run(app, host="0.0.0.0", port=7860)