# Somali_TTS_API / app.py
# Author: HusseinBashir — commit 54fd70a "Create app.py" (verified), 2.64 kB
import io
import re

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, Response
from transformers import VitsModel, AutoTokenizer
app = FastAPI()

# Load the model and tokenizer once, at import time; every request reuses them.
# NOTE(review): the weights and the tokenizer come from two different Hub
# repos — presumably built on the same character vocabulary; confirm whenever
# either repo is updated.
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move the weights onto that device and switch the model to inference mode.
model.to(device).eval()
# Spoken Somali words for the numerals that number_to_words() composes from.
# NOTE(review): spellings such as "koow", "labo" and "seddex" look phonetic
# rather than standard orthography — presumably chosen for the TTS voice;
# confirm before "correcting" them.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}

# A maximal run of digits anywhere in the input text.
_DIGIT_RUN = re.compile(r"\d+")


def number_to_words(number):
    """Spell out a non-negative integer in Somali words.

    Args:
        number: an int or any value accepted by ``int()`` (e.g. a digit string).

    Returns:
        The spoken form, with parts joined by "iyo" ("and"), e.g.
        250 -> "labo boqol iyo konton". Values of 1,000,000 or more have no
        word form here and are returned as their decimal digits.

    Raises:
        ValueError: if ``int(number)`` fails.
        KeyError: for negative values (they are not in ``number_words``).
    """
    number = int(number)
    if number < 20:
        return number_words[number]
    elif number < 100:
        tens, unit = divmod(number, 10)
        return number_words[tens * 10] + (" iyo " + number_words[unit] if unit else "")
    elif number < 1000:
        hundreds, remainder = divmod(number, 100)
        # Exactly one hundred is just "boqol"; otherwise "<count> boqol".
        part = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
        if remainder:
            part += " iyo " + number_to_words(remainder)
        return part
    elif number < 1000000:
        thousands, remainder = divmod(number, 1000)
        # Exactly one thousand is just "kun"; otherwise "<count> kun".
        words = [number_to_words(thousands) + " kun" if thousands != 1 else "kun"]
        if remainder:
            words.append("iyo " + number_to_words(remainder))
        return " ".join(words)
    else:
        # No word for millions in number_words: fall back to the digits.
        return str(number)


def _spell_number(match):
    """``re.sub`` callback: replace one digit run with its Somali words."""
    digits = match.group()
    try:
        return number_to_words(digits)
    except ValueError:
        # Python 3.11+ refuses int() on digit strings longer than the
        # interpreter's limit (4300 digits by default). Such a number is
        # far beyond one million anyway, so keep it as digits.
        return digits


def normalize_text(text):
    """Rewrite ``text`` into a form the TTS model pronounces better.

    Each run of digits is replaced by its Somali words (see
    ``number_to_words``). Then a few uppercase letter sequences are
    respelled: "KH" -> "qa", "Z" -> "S", "SH" -> "SHa'a", "DH" -> "Dha'a".
    The order matters: "Z" is rewritten before "SH" is expanded.
    Lowercase letters are left untouched.

    Args:
        text: the raw input string.

    Returns:
        The normalized string.
    """
    # Substitute every digit run in a single left-to-right pass. A
    # str.replace() per number would also rewrite shorter numbers embedded
    # in longer ones (the "1" inside "12").
    text = _DIGIT_RUN.sub(_spell_number, text)
    # "Z" -> "S" also turns "ZamZam" into "SamSam".
    text = text.replace("KH", "qa").replace("Z", "S")
    text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
    return text
def _synthesize_wav(text):
    """Synthesize ``text`` with the loaded VITS model and return WAV file bytes.

    The waveform is clipped to [-1, 1] and encoded as 16-bit PCM at
    ``model.config.sampling_rate``.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        waveform = model(**inputs).waveform.squeeze().cpu().numpy()
    # Clip before scaling. A sample even slightly outside [-1, 1] would land
    # outside the int16 range, where numpy's float-to-int cast is undefined
    # (typically wrapping to the opposite sign, an audible click).
    pcm = (np.clip(waveform, -1.0, 1.0) * 32767).astype(np.int16)
    # Encode in memory instead of into a fixed "output.wav" on disk. A
    # shared file could be overwritten by another request (or by another
    # process in the same directory) before the response finished streaming it.
    buffer = io.BytesIO()
    scipy.io.wavfile.write(buffer, rate=model.config.sampling_rate, data=pcm)
    return buffer.getvalue()


@app.post("/tts")
async def tts(request: Request):
    """Return a WAV rendering of the posted text.

    Expects a JSON object body with a string ``"text"`` field. The text is
    passed through ``normalize_text`` before synthesis.

    Returns:
        An ``audio/wav`` response containing 16-bit PCM audio.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not an object
            with a string ``"text"`` field, or the text is blank.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    text = data.get("text") if isinstance(data, dict) else None
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(
            status_code=400,
            detail='Body must be a JSON object with a non-empty string "text" field.',
        )
    # NOTE(review): tokenization and inference run synchronously on the event
    # loop, blocking other requests for their duration; consider
    # run_in_threadpool once concurrent use of the shared model is confirmed safe.
    audio = _synthesize_wav(normalize_text(text))
    return Response(content=audio, media_type="audio/wav")