|
|
import os
import random
import string
import tempfile
import wave
from io import BytesIO

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.responses import Response, JSONResponse
from scipy.io.wavfile import write
from TTS.api import TTS
|
|
|
|
|
# Scratch directory (relative to the working directory) where uploaded
# speaker reference clips are written before synthesis.
os.makedirs("temp/", exist_ok=True)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}", flush=True)
|
|
|
|
|
# Checkpoint and matching config for the TTS model. Both paths are resolved
# relative to the process working directory.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

# Plain string: the message has no placeholders, so an f-string is pointless.
print("Loading model", flush=True)

# Loaded once at import time and moved to the device chosen above; the
# request handler reuses this single module-level instance.
tts = TTS(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
    progress_bar=False,
).to(device)
|
|
|
|
|
# Sample rate (Hz) written into the returned WAV header. It must match the rate
# the model actually synthesizes at, otherwise the audio plays back too fast or
# too slow. Coqui's Synthesizer exposes that rate as ``output_sample_rate``; fall
# back to the previously hard-coded 22050 Hz if it is not available.
sample_rate = getattr(getattr(tts, "synthesizer", None), "output_sample_rate", None) or 22050
|
|
|
|
|
# Supported languages: lowercase language name -> language code passed to the
# TTS model. The request handler lower-cases the client's ``lang`` before
# looking it up here.
languageCODE = dict(
    bhojpuri="bho",
    bengali="bn",
    english="en",
    gujarati="gu",
    hindi="hi",
    chhattisgarhi="hne",
    kannada="kn",
    magahi="mag",
    maithili="mai",
    marathi="mr",
    telugu="te",
)
|
|
|
|
|
# The ASGI application that this module's routes are registered on.
app = FastAPI()


@app.get("/")
def Is_alive():
    """Return a static payload so callers can check that the server responds."""
    status_payload = dict(message="Server is Live")
    return status_payload
|
|
|
|
|
def _resolve_language_code(lang: str):
    """Map a language name or code to the code passed to the TTS model.

    Accepts a name from ``languageCODE`` (e.g. ``"Hindi"``) or one of its code
    values (e.g. ``"hi"``), case-insensitively.

    Returns:
        The language code, or ``None`` if the language is not supported.
    """
    lan = lang.lower()
    if lan in languageCODE:
        return languageCODE[lan]
    if lan in languageCODE.values():
        # The caller already supplied a code. Indexing languageCODE with it
        # would raise KeyError, so return it unchanged.
        return lan
    return None


def _to_wav_bytes(wav, rate: int) -> bytes:
    """Peak-normalize a float waveform to 16-bit PCM and encode it as WAV bytes.

    Args:
        wav: Sequence of float samples as returned by ``tts.tts``.
        rate: Sample rate in Hz to write into the WAV header.
    """
    samples = np.asarray(wav)
    # np.max raises on an empty array; treat empty output as silence.
    peak = float(np.max(np.abs(samples))) if samples.size else 0.0
    # Scale so the loudest sample reaches full int16 range. The 0.01 floor caps
    # the gain applied to near-silent output and avoids division by zero.
    pcm = (samples * (32767 / max(0.01, peak))).astype(np.int16)
    buffer = BytesIO()
    write(buffer, rate, pcm)
    return buffer.getvalue()


# Multipart uploads are only well-defined on POST (many clients and proxies drop
# GET bodies). POST is accepted as well, and GET is kept for existing callers.
@app.api_route("/Get_Inference", methods=["GET", "POST"])
async def Inference(text: str, lang: str, speaker_wav: UploadFile):
    """Synthesize *text* in *lang*, cloning the voice in the uploaded WAV.

    Args:
        text: Text to synthesize.
        lang: Language name (a key of ``languageCODE``) or language code,
            case-insensitive.
        speaker_wav: Uploaded WAV clip of the reference speaker.

    Returns:
        The synthesized audio as a 16-bit PCM WAV body (``audio/wav``), or a
        422 JSON response with a ``comment`` field for invalid input.
    """
    if not text or not lang or not speaker_wav:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    language_code = _resolve_language_code(lang)
    if language_code is None:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    speaker_bytes = await speaker_wav.read()

    # Validate the WAV header in memory before anything touches the disk.
    try:
        with wave.open(BytesIO(speaker_bytes), "rb"):
            pass
    except (wave.Error, EOFError, ValueError):
        return JSONResponse({"comment" : "Audio file format not supported."}, status_code = 422)

    # The TTS API takes the speaker reference as a path, so spill the upload to
    # a uniquely named file under temp/. The client-supplied filename is
    # deliberately not used because it may contain path separators ("../").
    fd, speaker_wav_filename = tempfile.mkstemp(suffix=".wav", dir="temp")
    try:
        with os.fdopen(fd, "wb") as wav_file:
            wav_file.write(speaker_bytes)
        # NOTE(review): tts.tts is synchronous and runs on the event loop, so
        # other requests (including "/") stall during synthesis. Consider
        # offloading it once the model's thread-safety is confirmed.
        wav = tts.tts(text=text, speaker_wav=speaker_wav_filename, language=language_code)
    finally:
        # Always clean up, including when synthesis raises.
        os.remove(speaker_wav_filename)

    return Response(_to_wav_bytes(wav, sample_rate), media_type="audio/wav")
|
|
|
|
|
|
|
|
def start_server(host="0.0.0.0", port=8080):
    """Serve this module's FastAPI ``app`` with uvicorn (blocks until shutdown).

    Args:
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to listen on.
    """
    print('Starting Server...', flush=True)

    # Pass the already-built app object rather than the import string
    # "API_Main:app". When this file is run as a script it executes as
    # ``__main__``, so the string form made uvicorn import it a second time as
    # ``API_Main``, reloading the TTS model. The string form also failed
    # outright if the file was renamed. An app object is allowed here because
    # reload is off and no extra workers are requested.
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="debug",
        reload=False,
    )


if __name__ == "__main__":
    start_server()