# Uploaded by SpireLab -- commit f15e280 ("Initial_commit", verified) -- 3.55 kB
import torch
import string
import random
import uvicorn
import numpy as np
from io import BytesIO
from TTS.api import TTS
from fastapi import FastAPI
from scipy.io.wavfile import write
from fastapi.responses import Response, JSONResponse
# Locations of the trained checkpoint and its configuration file.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

# Sample rate (Hz) passed to the WAV writer for every synthesized clip.
# NOTE(review): presumably matches the model's output rate -- confirm
# against models/config.json.
sample_rate = 22050

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load the model once at import time and move it to the chosen device.
print("Loading model")
tts = TTS(model_path=MODEL_PATH, config_path=CONFIG_PATH, progress_bar=False)
tts = tts.to(device)
# Maps a "<language>_<gender>" voice id to the reference recording used to
# condition synthesis for that voice. Every file lives under reference_audio/
# and is named after its voice id.
ref_path = {
    voice: f"reference_audio/{voice}.wav"
    for voice in (
        "chhattisgarhi_male",
        "chhattisgarhi_female",
        "kannada_male",
        "kannada_female",
        "maithili_male",
        "maithili_female",
        "telugu_male",
        "telugu_female",
        "bengali_male",
        "bengali_female",
        "bhojpuri_male",
        "bhojpuri_female",
        "marathi_female",
        "marathi_male",
        "gujarati_male",
        "gujarati_female",
        "hindi_male",
        "hindi_female",
        "magahi_female",
        "magahi_male",
        "english_female",
        "english_male",
    )
}
# Maps a lower-case language name to the short language code handed to the
# TTS model.
languageCODE = dict(
    bhojpuri="bho",
    bengali="bn",
    english="en",
    gujarati="gu",
    hindi="hi",
    chhattisgarhi="hne",
    kannada="kn",
    magahi="mag",
    maithili="mai",
    marathi="mr",
    telugu="te",
)
# ASGI application object; the route handler defined below registers itself
# on it through the @app.get decorator.
app = FastAPI()
@app.get("/Get_Inference")
async def Inference(text : str, lang : str, speaker : str):
    """Synthesize ``text`` in the requested voice and return it as a WAV file.

    Args:
        text: Text to speak.
        lang: Language, given either as a full name (a key of
            ``languageCODE``, e.g. ``"hindi"``) or as a language code (a
            value of ``languageCODE``, e.g. ``"hi"``). Case-insensitive.
        speaker: Voice id, a key of ``ref_path`` (e.g. ``"hindi_male"``).
            Case-insensitive.

    Returns:
        A ``Response`` carrying 16-bit PCM WAV bytes (``audio/wav``) on
        success; a ``JSONResponse`` with a ``"comment"`` field and status
        422 when a field is empty or the speaker/language is unknown, or
        status 500 when synthesis yields no samples.
    """
    if not text or not lang or not speaker:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    spk = speaker.lower()
    lan = lang.lower()

    if spk not in ref_path:
        return JSONResponse({"comment" : "Speaker not present in the system."}, status_code = 422)

    # Accept either a language name or a language code. The request is only
    # invalid when the value is neither (the previous `or` test rejected
    # every request, because no string is both a name and a code).
    if lan in languageCODE:
        lang_code = languageCODE[lan]
    elif lan in languageCODE.values():
        lang_code = lan
    else:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    # Look the reference audio up with the normalized key that was validated
    # above; indexing with the raw `speaker` raised KeyError on mixed case.
    # NOTE(review): tts.tts() looks like a synchronous call, so it would hold
    # the event loop for the duration of synthesis -- confirm before
    # changing the concurrency model.
    wav = np.array(tts.tts(text=text, speaker_wav=ref_path[spk], language=lang_code))
    if wav.size == 0:
        # np.max() raises ValueError on an empty array; report it cleanly.
        return JSONResponse({"comment" : "Synthesis produced no audio."}, status_code = 500)

    # Peak-normalize to the int16 range; the 0.01 floor avoids dividing by
    # (near) zero on silent output.
    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
    wav_norm = wav_norm.astype(np.int16)

    wav_buffer = BytesIO()
    write(wav_buffer, sample_rate, wav_norm)

    # Random suffix keeps file names distinct across requests. `lan` and
    # `spk` were validated against fixed ASCII tables, so they are safe to
    # place in a header value.
    suffix = "".join(
        random.choices(string.ascii_uppercase + string.digits + string.ascii_lowercase, k=7)
    )
    filename = f"{lan}_{spk}_{suffix}.wav"

    return Response(
        wav_buffer.getvalue(),
        media_type="audio/wav",
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )
def start_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8080; blocks until shutdown."""
    print('Starting Server...')
    # Hand uvicorn the app object instead of the "API_Main:app" import
    # string. The string form only works if this file is named API_Main.py,
    # and when the file is run as a script it makes uvicorn import the
    # module a second time, loading the TTS model twice. Passing the object
    # is supported as long as workers == 1 and reload is off.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        workers=1,
        log_level="debug",
        reload=False,
    )


if __name__ == "__main__":
    start_server()