MariaKaiser committed on
Commit
d8cdd10
·
verified ·
1 Parent(s): de2c11f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -58
app.py CHANGED
@@ -1,27 +1,18 @@
1
- # app.py
2
  from fastapi import FastAPI, UploadFile, File, Form
3
- from fastapi.responses import FileResponse, HTMLResponse
4
  import torch
5
  import torchaudio
6
  import os
7
- from pathlib import Path
8
  from TTS.tts.models.xtts import Xtts
9
  from TTS.tts.configs.xtts_config import XttsConfig
10
- import gradio as gr
11
- import uvicorn
12
 
13
- # ------------------------
14
- # Setup paths
15
- # ------------------------
16
- MODEL_DIR = "my_model" # folder with config.json, vocab.json, model.pth
17
  OUTPUT_DIR = "outputs"
18
  os.makedirs(OUTPUT_DIR, exist_ok=True)
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # ------------------------
23
- # Load TTS model
24
- # ------------------------
25
  config = XttsConfig()
26
  config.load_json(os.path.join(MODEL_DIR, "config.json"))
27
 
@@ -34,9 +25,6 @@ model.load_checkpoint(
34
  )
35
  model.to(device)
36
 
37
- # ------------------------
38
- # TTS function
39
- # ------------------------
40
  def tts_arabic(text: str, audio_file: str) -> str:
41
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
42
  out = model.inference(
@@ -54,58 +42,20 @@ def tts_arabic(text: str, audio_file: str) -> str:
54
  torchaudio.save(output_wav, torch.tensor(out["wav"]).unsqueeze(0), 24000)
55
  return output_wav
56
 
57
- # ------------------------
58
- # FastAPI setup
59
- # ------------------------
60
- app = FastAPI(title="EGTTS TTS API")
61
 
62
- @app.get("/", response_class=HTMLResponse)
63
- def index():
64
- """Return simple HTML that links to Gradio UI"""
65
- return """
66
- <h2>Welcome to EGTTS TTS API</h2>
67
- <p>Swagger docs available at <a href="/docs">/docs</a></p>
68
- <p>Try the Gradio interface at <a href="/gradio">/gradio</a></p>
69
- """
70
 
71
  @app.post("/tts/")
72
  async def tts_endpoint(
73
  text: str = Form(...),
74
  audio_file: UploadFile = File(...)
75
  ):
76
- # Save uploaded file
77
  file_path = os.path.join(OUTPUT_DIR, audio_file.filename)
78
  with open(file_path, "wb") as f:
79
  f.write(await audio_file.read())
80
 
81
  output_wav = tts_arabic(text, file_path)
82
- return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")
83
-
84
- # ------------------------
85
- # Gradio interface
86
- # ------------------------
87
- def gradio_fn(text, audio_file):
88
- return tts_arabic(text, audio_file.name)
89
-
90
- gradio_interface = gr.Interface(
91
- fn=gradio_fn,
92
- inputs=[
93
- gr.Textbox(label="Arabic Text", placeholder="اكتب النص هنا..."),
94
- gr.File(label="Speaker Audio (.wav)")
95
- ],
96
- outputs=gr.Audio(label="Generated Speech"),
97
- live=True,
98
- title="EGTTS Arabic TTS",
99
- description="Generate Arabic speech from text using your fine-tuned EGTTS model."
100
- )
101
-
102
- # Mount Gradio inside FastAPI
103
- @app.get("/gradio", response_class=HTMLResponse)
104
- def gradio_ui():
105
- return gradio_interface.launch(inline=True, share=False, prevent_thread_lock=True).read()
106
-
107
- # ------------------------
108
- # Run server
109
- # ------------------------
110
- if __name__ == "__main__":
111
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  from fastapi import FastAPI, UploadFile, File, Form
2
+ from fastapi.responses import FileResponse
3
  import torch
4
  import torchaudio
5
  import os
 
6
  from TTS.tts.models.xtts import Xtts
7
  from TTS.tts.configs.xtts_config import XttsConfig
 
 
8
 
9
+ MODEL_DIR = "my_model"
 
 
 
10
  OUTPUT_DIR = "outputs"
11
  os.makedirs(OUTPUT_DIR, exist_ok=True)
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # Load model
 
 
16
  config = XttsConfig()
17
  config.load_json(os.path.join(MODEL_DIR, "config.json"))
18
 
 
25
  )
26
  model.to(device)
27
 
 
 
 
28
  def tts_arabic(text: str, audio_file: str) -> str:
29
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
30
  out = model.inference(
 
42
  torchaudio.save(output_wav, torch.tensor(out["wav"]).unsqueeze(0), 24000)
43
  return output_wav
44
 
45
+ app = FastAPI(title="EGTTS Arabic TTS API")
 
 
 
46
 
47
+ @app.get("/")
48
+ def root():
49
+ return {"message": "Welcome! Visit /docs for Swagger UI."}
 
 
 
 
 
50
 
51
  @app.post("/tts/")
52
  async def tts_endpoint(
53
  text: str = Form(...),
54
  audio_file: UploadFile = File(...)
55
  ):
 
56
  file_path = os.path.join(OUTPUT_DIR, audio_file.filename)
57
  with open(file_path, "wb") as f:
58
  f.write(await audio_file.read())
59
 
60
  output_wav = tts_arabic(text, file_path)
61
+ return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")