# API_Main.py (SpireLab, commit 8017cce "Update API_Main.py", 2.85 kB).
import os
import random
import string
import tempfile
import wave
from io import BytesIO

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse, Response
from scipy.io.wavfile import write
from TTS.api import TTS
# Scratch directory for uploaded speaker reference clips (relative to the CWD).
os.makedirs("temp/", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}", flush=True)

MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

print("Loading model", flush=True)
tts = TTS(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
    progress_bar=False,
).to(device)

# Sample rate written into the WAV header of every response. Prefer the rate
# reported by the loaded Coqui synthesizer so the header matches what the model
# actually generates; fall back to the previously hard-coded 22050 Hz when that
# attribute is not available.
# NOTE(review): a hard-coded rate that differs from the model's output rate
# would play back pitch/speed-shifted audio — confirm against models/config.json.
sample_rate = (
    getattr(getattr(tts, "synthesizer", None), "output_sample_rate", None) or 22050
)

# Maps the lower-case language names accepted by the API to the language codes
# passed to ``tts.tts(language=...)``. The codes themselves are also accepted.
languageCODE = {
    "bhojpuri": "bho",
    "bengali": "bn",
    "english": "en",
    "gujarati": "gu",
    "hindi": "hi",
    "chhattisgarhi": "hne",
    "kannada": "kn",
    "magahi": "mag",
    "maithili": "mai",
    "marathi": "mr",
    "telugu": "te",
}

app = FastAPI()
@app.get("/")
def Is_alive():
    """Health check: respond with a fixed liveness message."""
    status = {"message": "Server is Live"}
    return status
def _resolve_language_code(lang):
    """Return the model language code for ``lang``, or None if unsupported.

    ``lang`` may be a language name from ``languageCODE`` (case-insensitive,
    e.g. "Hindi") or one of its code values (e.g. "hi").
    """
    key = lang.strip().lower()
    if key in languageCODE:
        return languageCODE[key]
    if key in languageCODE.values():
        return key
    return None


def _to_wav_bytes(wav, rate):
    """Peak-normalize ``wav`` to 16-bit PCM and return it encoded as WAV bytes.

    The peak used for scaling is floored at 0.01. This caps the gain applied to
    very quiet output and avoids dividing by zero for an all-zero (or empty)
    signal.
    """
    samples = np.asarray(wav)
    # np.max raises on an empty array; treat "no samples" as silence.
    peak = float(np.max(np.abs(samples))) if samples.size else 0.0
    scaled = samples * (32767 / max(0.01, peak))
    buffer = BytesIO()
    write(buffer, rate, scaled.astype(np.int16))
    return buffer.getvalue()


# GET is kept for existing clients; POST is also accepted because many HTTP
# clients and proxies do not send a request body (the uploaded file) with GET.
@app.api_route("/Get_Inference", methods=["GET", "POST"])
async def Inference(text: str, lang: str, speaker_wav: UploadFile):
    """Synthesize ``text`` in ``lang`` using the voice of an uploaded WAV clip.

    Args:
        text: Text to speak.
        lang: Language name (e.g. "hindi") or code (e.g. "hi") from
            ``languageCODE``.
        speaker_wav: Uploaded reference recording; must be readable by the
            stdlib ``wave`` module.

    Returns:
        A ``Response`` with the synthesized audio as 16-bit WAV
        (``audio/wav``), or a 422 ``JSONResponse`` with a ``comment`` when a
        field is missing, the language is unsupported, or the upload is not a
        readable WAV file.
    """
    if not text or not lang or not speaker_wav:
        return JSONResponse({"comment": "Missing Field."}, status_code=422)

    language_code = _resolve_language_code(lang)
    if language_code is None:
        return JSONResponse(
            {"comment": "Language not present in the system."}, status_code=422
        )

    # mkstemp picks a unique, server-chosen name inside temp/. The client's
    # filename is never used in the path, which prevents path traversal
    # ("../..") and collisions.
    fd, speaker_wav_filename = tempfile.mkstemp(dir="temp/", suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as wav_file:
            wav_file.write(await speaker_wav.read())
        try:
            # Only checks that the upload parses as a WAV file.
            with wave.open(speaker_wav_filename, "rb"):
                pass
        except (wave.Error, EOFError):
            return JSONResponse(
                {"comment": "Audio file format not supported."}, status_code=422
            )
        # NOTE: tts.tts is synchronous and runs on the event loop, so
        # synthesis calls never overlap (the shared model is used serially).
        wav = tts.tts(
            text=text, speaker_wav=speaker_wav_filename, language=language_code
        )
    finally:
        # Remove the upload on every path, including errors and early returns.
        os.remove(speaker_wav_filename)

    return Response(_to_wav_bytes(wav, sample_rate), media_type="audio/wav")
def start_server(host="0.0.0.0", port=8080, log_level="debug"):
    """Serve ``app`` with uvicorn; blocks until the server stops.

    The ``app`` object is passed to uvicorn directly instead of the import
    string ``"API_Main:app"``. When this file runs as a script, it is loaded as
    ``__main__``. With the import string, uvicorn would import the file a second
    time as ``API_Main``, re-running all module-level code, including the TTS
    model load. Passing the object is allowed because reload is disabled and a
    single worker is used.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
        log_level: uvicorn log level.
    """
    print("Starting Server...", flush=True)
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level=log_level,
        reload=False,
    )


if __name__ == "__main__":
    start_server()