"""HTTP inference server for a multilingual, speaker-conditioned TTS model.

Importing this module loads the model checkpoint onto the best available
device and builds the FastAPI ``app``.  Run the file directly to serve it
with uvicorn on port 8080.
"""
import os
import random
import string
import wave
from contextlib import suppress
from io import BytesIO

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse, Response
from scipy.io.wavfile import write
from TTS.api import TTS

# Scratch directory for uploaded speaker reference files.
TEMP_DIR = "temp/"
os.makedirs(TEMP_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}", flush=True)

MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

print("Loading model", flush=True)
tts = TTS(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
    progress_bar=False,
).to(device)

# Sample rate written into the returned WAV header.
# NOTE(review): presumably matches the model's output rate -- confirm
# against models/config.json.
sample_rate = 22050

# Accepted language names mapped to the codes passed to the model.
languageCODE = {
    "bhojpuri": "bho",
    "bengali": "bn",
    "english": "en",
    "gujarati": "gu",
    "hindi": "hi",
    "chhattisgarhi": "hne",
    "kannada": "kn",
    "magahi": "mag",
    "maithili": "mai",
    "marathi": "mr",
    "telugu": "te",
}

app = FastAPI()


@app.get("/")
def Is_alive():
    """Liveness probe: return a static message when the server is up."""
    return {"message": "Server is Live"}


def _resolve_language_code(lang):
    """Return the model code for a language name or code, or None if unknown."""
    key = lang.lower()
    if key in languageCODE:
        return languageCODE[key]
    if key in languageCODE.values():
        # Callers may pass the code itself (e.g. "hi"); use it as-is.
        return key
    return None


def _to_pcm16(wav):
    """Peak-normalise a float waveform and convert it to int16 samples."""
    # Guard the empty case: np.max raises on a zero-length array.
    peak = np.max(np.abs(wav)) if wav.size else 0.0
    # The 0.01 floor avoids blowing up (or dividing by zero on) silence.
    scaled = wav * (32767 / max(0.01, peak))
    return scaled.astype(np.int16)


# NOTE(review): a file upload is normally served over POST; the route is
# kept as GET so existing clients keep working.
@app.get("/Get_Inference")
async def Inference(text: str, lang: str, speaker_wav: UploadFile):
    """Synthesise ``text`` and return it as WAV audio.

    Args:
        text: Text to synthesise.
        lang: Language name (a key of ``languageCODE``) or its code,
            matched case-insensitively.
        speaker_wav: Uploaded WAV file passed to the model as the speaker
            reference.

    Returns:
        A ``Response`` holding 16-bit PCM WAV bytes on success, or a
        ``JSONResponse`` with status 422 and a ``comment`` field when a
        field is missing, the language is unknown, or the upload cannot be
        opened as a WAV file.
    """
    if not text or not lang or not speaker_wav:
        return JSONResponse({"comment": "Missing Field."}, status_code=422)

    language_code = _resolve_language_code(lang)
    if language_code is None:
        return JSONResponse(
            {"comment": "Language not present in the system."},
            status_code=422,
        )

    # The filename is client-controlled: keep only its final component so
    # it cannot escape TEMP_DIR, and fall back when it is missing.
    raw_name = (speaker_wav.filename or "").replace("\\", "/")
    upload_name = os.path.basename(raw_name) or "speaker.wav"
    token = "".join(random.choices(string.ascii_letters + string.digits, k=5))
    speaker_wav_filename = os.path.join(TEMP_DIR, f"{token}_{upload_name}")

    try:
        with open(speaker_wav_filename, "wb") as wav_file:
            wav_file.write(await speaker_wav.read())

        # Opening with the wave module is the format check; nothing is read.
        try:
            with wave.open(speaker_wav_filename, "rb"):
                pass
        except (wave.Error, EOFError):
            return JSONResponse(
                {"comment": "Audio file format not supported."},
                status_code=422,
            )

        wav = np.array(
            tts.tts(
                text=text,
                speaker_wav=speaker_wav_filename,
                language=language_code,
            )
        )
    finally:
        # Always remove the upload, including on early return or failure.
        with suppress(FileNotFoundError):
            os.remove(speaker_wav_filename)

    wav_buffer = BytesIO()
    write(wav_buffer, sample_rate, _to_pcm16(wav))
    return Response(wav_buffer.getvalue(), media_type="audio/wav")


def start_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8080 (blocking)."""
    print('Starting Server...')
    # Pass the app object rather than an import string: an import string
    # would re-import this module (loading the model a second time) and
    # only works when the file is named API_Main.py.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="debug",
        reload=False,
    )


if __name__ == "__main__":
    start_server()