|
|
import torch |
|
|
import string |
|
|
import random |
|
|
import uvicorn |
|
|
import numpy as np |
|
|
from io import BytesIO |
|
|
from TTS.api import TTS |
|
|
from fastapi import FastAPI |
|
|
from scipy.io.wavfile import write |
|
|
from fastapi.responses import Response, JSONResponse |
|
|
|
|
|
|
|
|
# Prefer GPU inference when CUDA is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Fine-tuned checkpoint and its matching Coqui-TTS config.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

print("Loading model")  # fix: was an f-string with no placeholders (lint F541)
tts = TTS(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
    progress_bar=False,
).to(device)

# Sample rate written into the WAV header of every response; presumably the
# model's native output rate — TODO confirm against models/config.json.
sample_rate = 22050
|
|
|
|
|
# Speaker ids ("<language>_<gender>") in registration order; each one maps
# to a same-named reference clip used for voice cloning.
_SPEAKER_IDS = (
    "chhattisgarhi_male", "chhattisgarhi_female",
    "kannada_male", "kannada_female",
    "maithili_male", "maithili_female",
    "telugu_male", "telugu_female",
    "bengali_male", "bengali_female",
    "bhojpuri_male", "bhojpuri_female",
    "marathi_female", "marathi_male",
    "gujarati_male", "gujarati_female",
    "hindi_male", "hindi_female",
    "magahi_female", "magahi_male",
    "english_female", "english_male",
)

# Speaker id -> path of its reference-audio WAV. Keys must match the
# lowercased `speaker` query parameter of the inference endpoint.
ref_path = {speaker_id: f"reference_audio/{speaker_id}.wav" for speaker_id in _SPEAKER_IDS}
|
|
|
|
|
# Lowercased language name -> language code accepted by the TTS model.
languageCODE = dict(
    bhojpuri="bho",
    bengali="bn",
    english="en",
    gujarati="gu",
    hindi="hi",
    chhattisgarhi="hne",
    kannada="kn",
    magahi="mag",
    maithili="mai",
    marathi="mr",
    telugu="te",
)
|
|
|
|
|
app = FastAPI() |
|
|
@app.get("/Get_Inference")
async def Inference(text : str, lang : str, speaker : str):
    """Synthesize `text` in the given language with the given cloned speaker.

    Query parameters:
        text:    text to synthesize.
        lang:    language name ("hindi") or its code ("hi"); case-insensitive.
        speaker: key into `ref_path` ("hindi_male", ...); case-insensitive.

    Returns the synthesized audio as a WAV response, or a 422 JSON error
    when any parameter is missing or unknown.
    """
    if not text or not lang or not speaker:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    spk = speaker.lower()
    lan = lang.lower()

    if spk not in ref_path:
        return JSONResponse({"comment" : "Speaker not present in the system."}, status_code = 422)

    # Accept either a language name (dict key) or a language code (dict value).
    # BUGFIX: the original used `or` between the two "not in" checks — since no
    # string is both a key and a value, that condition was always true and
    # every request was rejected with 422.
    if lan in languageCODE:
        lang_code = languageCODE[lan]
    elif lan in languageCODE.values():
        lang_code = lan
    else:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    # BUGFIX: index with the lowercased `spk` (which was validated above), not
    # the raw `speaker` — mixed-case input previously raised KeyError here.
    wav = np.array(tts.tts(text=text, speaker_wav=ref_path[spk], language=lang_code))

    # Peak-normalize into int16 range; the 0.01 floor avoids dividing by
    # zero (or amplifying noise) when the model emits near-silence.
    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
    wav_norm = wav_norm.astype(np.int16)

    wav_buffer = BytesIO()
    write(wav_buffer, sample_rate, wav_norm)
    wav_buffer.seek(0)

    # Random 7-char suffix keeps downloaded filenames unique. The original
    # assigned this to `wav_buffer.name`, which was never sent to the client;
    # expose it via Content-Disposition instead, and declare the media type.
    filename = (
        lang + "_" + speaker + "_"
        + ''.join(random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase) for _ in range(7))
        + ".wav"
    )
    return Response(
        wav_buffer.read(),
        media_type="audio/wav",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
|
|
|
|
|
|
|
def start_server():
    """Launch the uvicorn server hosting this module's app; blocks until shutdown."""
    print('Starting Server...')

    server_options = dict(
        host="0.0.0.0",
        port=8080,
        workers=1,
        log_level="debug",
        reload=False,
    )
    # App passed as an import string ("module:attribute") so uvicorn can
    # (re)import it in worker processes.
    uvicorn.run("API_Main:app", **server_options)
|
|
|
|
|
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported
    # (uvicorn itself re-imports this module via the "API_Main:app" string).
    start_server()