File size: 3,375 Bytes
99c5e44
6bd99af
 
99c5e44
 
 
6bd99af
 
 
 
 
99c5e44
 
6bd99af
99c5e44
6bd99af
99c5e44
 
 
 
 
6bd99af
 
 
99c5e44
 
 
 
 
 
 
 
 
 
 
 
 
6bd99af
99c5e44
6bd99af
 
99c5e44
 
 
 
 
 
 
 
 
6bd99af
99c5e44
 
 
 
 
 
6bd99af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99c5e44
6bd99af
 
 
 
 
99c5e44
 
 
 
 
 
 
 
 
 
6bd99af
 
 
 
 
99c5e44
6bd99af
99c5e44
 
6bd99af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# app.py
import os
import uuid
from pathlib import Path

import gradio as gr
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import FileResponse, HTMLResponse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# ------------------------
# Setup paths
# ------------------------
# MODEL_DIR must contain config.json, vocab.json and model.pth (the checkpoint).
MODEL_DIR = "my_model"
# Uploaded reference clips and generated audio are written here.
OUTPUT_DIR = "outputs"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ------------------------
# Load TTS model
# ------------------------
# This runs at import time: importing app.py builds and loads the full model.
# Read the XTTS configuration from MODEL_DIR/config.json.
config = XttsConfig()
config.load_json(os.path.join(MODEL_DIR, "config.json"))

# Build the network described by the config, then load the fine-tuned
# weights found in MODEL_DIR. vocab.json is passed explicitly as the
# vocabulary file (presumably the text tokenizer's vocab — confirm).
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_dir=MODEL_DIR,
    use_deepspeed=False,  # DeepSpeed is not used for loading/inference here
    vocab_path=os.path.join(MODEL_DIR, "vocab.json")
)
# Move the model to the device chosen above (GPU if available, else CPU).
model.to(device)

# ------------------------
# TTS function
# ------------------------
def tts_arabic(text: str, audio_file: str, output_path=None, sample_rate: int = 24000) -> str:
    """Synthesize Arabic speech for *text* in the voice of *audio_file*.

    The speaker clip is converted into XTTS conditioning latents, then
    ``model.inference`` runs with the sampling settings stored in the model
    config (temperature, top_k, top_p, length and repetition penalties).

    Args:
        text: Arabic text to speak (synthesized with ``language="ar"``).
        audio_file: Path to a reference recording of the target speaker.
        output_path: Where to write the WAV. Defaults to a new, uniquely named
            file in OUTPUT_DIR, so concurrent calls cannot overwrite each
            other's result (a single fixed ``output.wav`` would be clobbered).
        sample_rate: Sample rate written to the WAV file. The default 24000 is
            the rate this module has always used. NOTE(review): keep it in sync
            with the model's actual output rate if the config changes.

    Returns:
        Path of the written WAV file. Files are not deleted here; callers or a
        cleanup job are responsible for removing them.
    """
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=[audio_file]
    )
    cfg = model.config
    out = model.inference(
        text=text,
        language="ar",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=cfg.temperature,
        top_k=cfg.top_k,
        length_penalty=cfg.length_penalty,
        repetition_penalty=cfg.repetition_penalty,
        top_p=cfg.top_p,
    )

    if output_path is None:
        # Unique name per call: avoids races between overlapping requests.
        output_path = os.path.join(OUTPUT_DIR, f"output_{uuid.uuid4().hex}.wav")

    # torchaudio.save takes (channels, frames); this assumes out["wav"] is a
    # 1-D mono waveform, so a leading channel axis is added.
    waveform = torch.as_tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(output_path, waveform, sample_rate)
    return output_path

# ------------------------
# FastAPI setup
# ------------------------
app = FastAPI(title="EGTTS TTS API")


@app.get("/", response_class=HTMLResponse)
def index():
    """Serve a small landing page linking to the Swagger docs and the Gradio UI."""
    landing_page = """
    <h2>Welcome to EGTTS TTS API</h2>
    <p>Swagger docs available at <a href="/docs">/docs</a></p>
    <p>Try the Gradio interface at <a href="/gradio">/gradio</a></p>
    """
    return landing_page

@app.post("/tts/")
async def tts_endpoint(
    text: str = Form(...),
    audio_file: UploadFile = File(...)
):
    """Speak ``text`` in the voice of the uploaded reference clip.

    Form fields:
        text: Arabic text to synthesize.
        audio_file: Reference recording of the target speaker.

    Returns:
        The generated speech as an ``audio/wav`` download named ``output.wav``.

    NOTE(review): ``tts_arabic`` is synchronous and runs inside this coroutine,
    so it blocks the event loop for the whole synthesis. Consider
    ``fastapi.concurrency.run_in_threadpool`` once the shared model is
    confirmed safe for concurrent use.
    """
    # Never build the path from the client-supplied filename. It may be None,
    # contain path components such as "../", or collide with other uploads or
    # with generated files. Only a plain alphanumeric extension is kept, since
    # the audio decoder may use it as a format hint.
    suffix = Path(audio_file.filename or "").suffix
    if not suffix[1:].isalnum():
        suffix = ".wav"
    file_path = os.path.join(OUTPUT_DIR, f"speaker_{uuid.uuid4().hex}{suffix}")
    with open(file_path, "wb") as f:
        f.write(await audio_file.read())

    try:
        output_wav = tts_arabic(text, file_path)
    finally:
        # The uploaded clip is only needed during synthesis; don't let uploads
        # accumulate on disk (also cleaned up when synthesis fails).
        Path(file_path).unlink(missing_ok=True)
    return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")

# ------------------------
# Gradio interface
# ------------------------
def gradio_fn(text, audio_file):
    """Gradio callback: synthesize *text* in the voice of the uploaded file.

    Args:
        text: Contents of the text box.
        audio_file: Value produced by ``gr.File``. Depending on the Gradio
            version this is a filepath string or a tempfile-like object that
            exposes the path as ``.name``. It is ``None`` until a file is uploaded.

    Returns:
        Path of the generated WAV, or ``None`` (empty audio output) when there
        is no text or no speaker file yet. Without this guard the callback
        raised ``AttributeError`` on ``None.name``.
    """
    if not (text and text.strip()) or audio_file is None:
        return None
    speaker_path = audio_file if isinstance(audio_file, str) else audio_file.name
    return tts_arabic(text, speaker_path)

# Text + speaker-clip inputs, generated speech output. The placeholder reads
# "Write the text here...".
# live=False: synthesis runs only when the user presses Submit. With live=True
# Gradio reruns the function on every input change, i.e. a full XTTS inference
# per keystroke in the text box.
gradio_interface = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Arabic Text", placeholder="اكتب النص هنا..."),
        gr.File(label="Speaker Audio (.wav)"),
    ],
    outputs=gr.Audio(label="Generated Speech"),
    live=False,
    title="EGTTS Arabic TTS",
    description="Generate Arabic speech from text using your fine-tuned EGTTS model.",
)

# Mount Gradio inside FastAPI: serve the Gradio UI as an ASGI sub-application
# under /gradio (the path linked from the landing page). Calling
# Interface.launch() from a request handler does not work: it starts a separate
# server on every call and returns a tuple, which has no .read() method.
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")

# ------------------------
# Run server
# ------------------------
if __name__ == "__main__":
    # Serve the FastAPI app on all network interfaces, port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)