File size: 3,629 Bytes
f15e280
 
 
 
 
 
 
 
 
 
 
 
 
452b9e6
f15e280
 
 
 
452b9e6
f15e280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452b9e6
 
 
 
f15e280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import string
import random
import uvicorn
import numpy as np
from io import BytesIO
from TTS.api import TTS
from fastapi import FastAPI
from scipy.io.wavfile import write
from fastapi.responses import Response, JSONResponse


# Prefer the GPU when torch can see one; otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}", flush=True)

# Checkpoint and config of the served model. These are relative paths, so they
# resolve against the process's current working directory.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

# The model is loaded once at import time and shared by every request.
print("Loading model", flush=True)
tts = TTS(model_path=MODEL_PATH, config_path=CONFIG_PATH, progress_bar=False)
tts = tts.to(device)

# Sample rate written into the WAV header of every response. Take it from the
# loaded synthesizer so the header matches the rate the model actually
# produces (a mismatch plays back at the wrong speed/pitch); fall back to the
# previous hard-coded 22050 Hz when the attribute is unavailable.
# NOTE(review): assumes this Coqui TTS version exposes
# `tts.synthesizer.output_sample_rate` — confirm; the fallback covers it if not.
sample_rate = int(
    getattr(getattr(tts, "synthesizer", None), "output_sample_rate", None)
    or 22050
)

# Reference recordings handed to the model as `speaker_wav`, keyed by
# "<language>_<gender>". Every key maps to reference_audio/<key>.wav; the
# tuple below fixes the dict's insertion order.
ref_path = {
    voice: f"reference_audio/{voice}.wav"
    for voice in (
        "chhattisgarhi_male",
        "chhattisgarhi_female",
        "kannada_male",
        "kannada_female",
        "maithili_male",
        "maithili_female",
        "telugu_male",
        "telugu_female",
        "bengali_male",
        "bengali_female",
        "bhojpuri_male",
        "bhojpuri_female",
        "marathi_female",
        "marathi_male",
        "gujarati_male",
        "gujarati_female",
        "hindi_male",
        "hindi_female",
        "magahi_female",
        "magahi_male",
        "english_female",
        "english_male",
    )
}

# Lower-case language names accepted by /Get_Inference, mapped to the
# language code passed to the model.
languageCODE = dict(
    bhojpuri="bho",
    bengali="bn",
    english="en",
    gujarati="gu",
    hindi="hi",
    chhattisgarhi="hne",
    kannada="kn",
    magahi="mag",
    maithili="mai",
    marathi="mr",
    telugu="te",
)

# ASGI application object; the route handlers below register on it via decorators.
app = FastAPI()
@app.get("/")
def Is_alive():
    """Health check: report that the server process is up and serving requests."""
    status = {"message": "Server is Live"}
    return status

@app.get("/Get_Inference")
async def Inference(text : str, lang : str, speaker : str):
    """Synthesize *text* in the voice of *speaker* and return it as a WAV file.

    Query parameters:
        text: Text to synthesize.
        lang: Language, given either as a key of ``languageCODE`` (e.g.
            ``"hindi"``) or as one of its codes (e.g. ``"hi"``); matched
            case-insensitively.
        speaker: Key of ``ref_path`` (e.g. ``"hindi_female"``); matched
            case-insensitively.

    Returns:
        A ``Response`` whose body is a 16-bit PCM WAV at ``sample_rate``
        (``audio/wav``, with a suggested file name), or a 422 ``JSONResponse``
        carrying a ``comment`` when a field is empty or unknown.
    """
    if not text or not lang or not speaker:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    spk = speaker.lower()
    lan = lang.lower()

    if spk not in ref_path:
        return JSONResponse({"comment" : "Speaker not present in the system."}, status_code = 422)

    # Accept either a language name or a language code. No string is both,
    # so the two lookups must be alternatives (elif), not combined with `or`.
    if lan in languageCODE:
        lang_code = languageCODE[lan]
    elif lan in languageCODE.values():
        lang_code = lan
    else:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    # Index with the normalized key `spk`: that is the form validated above,
    # so mixed-case input cannot raise KeyError here.
    # NOTE(review): tts.tts() is a blocking call inside an `async` handler, so
    # it stalls the event loop (including "/") while synthesizing. Consider a
    # worker thread once the model's thread-safety is confirmed.
    wav = np.asarray(tts.tts(text=text, speaker_wav=ref_path[spk], language=lang_code))

    # Peak-normalize into the int16 range. The 0.01 floor avoids division by
    # zero and caps the gain applied to near-silent output.
    peak = float(np.max(np.abs(wav))) if wav.size else 0.0
    wav_norm = (wav * (32767 / max(0.01, peak))).astype(np.int16)

    wav_buffer = BytesIO()
    write(wav_buffer, sample_rate, wav_norm)

    # Random suffix only keeps downloaded file names distinct; it is not a secret.
    suffix = ''.join(random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase) for _ in range(7))
    filename = f"{lan}_{spk}_{suffix}.wav"
    return Response(
        content=wav_buffer.getvalue(),
        media_type="audio/wav",
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )


def start_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8080 (debug logging, no reload).

    The already-built ``app`` object is passed to uvicorn directly instead of
    the ``"API_Main:app"`` import string. When this file runs as a script it
    is the ``__main__`` module, so the import string made uvicorn import the
    file a second time under the name ``API_Main``. That re-ran the
    module-level model load, and it required the file to be named
    ``API_Main.py``. An import string is only needed for ``reload``/``workers``,
    and neither is used here.
    """
    print('Starting Server...')

    uvicorn.run(
        app,
        host = "0.0.0.0",
        port = 8080,
        log_level="debug",
        reload=False,
    )

# Start the HTTP server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    start_server()