HusseinBashir commited on
Commit
91f8a04
·
verified ·
1 Parent(s): e00a04e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -17
app.py CHANGED
@@ -1,20 +1,36 @@
1
- # app.py
 
 
 
 
2
 
3
- from fastapi import FastAPI, Request
4
- from fastapi.responses import JSONResponse
5
- from gradio_client import Client
6
 
7
- app = FastAPI()
8
- client = Client("HusseinBashir/Somali_tts")
9
 
10
- @app.post("/somali-tts/")
11
- async def somali_tts(request: Request):
12
- data = await request.json()
13
- text = data.get("text")
14
- if not text:
15
- return JSONResponse(content={"error": "No text provided"}, status_code=400)
16
- try:
17
- audio_url = client.predict(text, api_name="/predict")
18
- return JSONResponse(content={"audio_url": audio_url})
19
- except Exception as e:
20
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from TTS.models.vits import VitsModel
4
+ from transformers import AutoTokenizer
5
+ import torchaudio
6
 
7
+ # Load the model and tokenizer from Hugging Face
8
+ model = VitsModel.from_pretrained("HusseinBashir/codad_tijaabo")
9
+ tokenizer = AutoTokenizer.from_pretrained("HusseinBashir/codad_tijaabo")
10
 
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = model.to(device).eval()
13
 
14
+ def tts_infer(text):
15
+ inputs = tokenizer(text, return_tensors="pt")
16
+ input_ids = inputs.input_ids.to(device)
17
+
18
+ with torch.no_grad():
19
+ output = model(input_ids)
20
+ waveform = output["waveform"]
21
+
22
+ # Save or return audio
23
+ sample_rate = 22050 # VITS typically uses 22.05kHz
24
+ torchaudio.save("output.wav", waveform.cpu(), sample_rate)
25
+ return "output.wav"
26
+
27
+ # Create Gradio UI
28
+ interface = gr.Interface(
29
+ fn=tts_infer,
30
+ inputs=gr.Textbox(label="Geli qoraalka aad rabto in cod laga dhigo"),
31
+ outputs=gr.Audio(label="Codka la sameeyey"),
32
+ title="Codad Tijaabo TTS",
33
+ description="Ku qor qoraal Soomaali ah si aad cod u maqasho.",
34
+ )
35
+
36
+ interface.launch()