File size: 3,549 Bytes
f15e280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import string
import random
import uvicorn
import numpy as np
from io import BytesIO
from TTS.api import TTS
from fastapi import FastAPI
from scipy.io.wavfile import write
from fastapi.responses import Response, JSONResponse


# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Local checkpoint and config of the TTS model served by this API.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"

print("Loading model")
# Build the model from the local checkpoint/config paths above and move it to
# the selected device. This runs once, at import time.
tts = TTS(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
    progress_bar=False,
).to(device)

# Sample rate (Hz) written into the WAV header of every response.
# NOTE(review): hard-coded; confirm it equals the output sample rate of the
# model described by CONFIG_PATH, otherwise clients hear the audio at the
# wrong speed/pitch.
sample_rate = 22050

# Reference recordings used for voice cloning, keyed by "<language>_<gender>".
# Every key maps to the file reference_audio/<key>.wav; the tuple below fixes
# the key order of the resulting dict.
ref_path = {
    speaker_id: f"reference_audio/{speaker_id}.wav"
    for speaker_id in (
        "chhattisgarhi_male",
        "chhattisgarhi_female",
        "kannada_male",
        "kannada_female",
        "maithili_male",
        "maithili_female",
        "telugu_male",
        "telugu_female",
        "bengali_male",
        "bengali_female",
        "bhojpuri_male",
        "bhojpuri_female",
        "marathi_female",
        "marathi_male",
        "gujarati_male",
        "gujarati_female",
        "hindi_male",
        "hindi_female",
        "magahi_female",
        "magahi_male",
        "english_female",
        "english_male",
    )
}

# Supported languages: lowercase language name -> language code handed to the
# TTS model's `language` argument.
languageCODE = dict(
    bhojpuri="bho",
    bengali="bn",
    english="en",
    gujarati="gu",
    hindi="hi",
    chhattisgarhi="hne",
    kannada="kn",
    magahi="mag",
    maithili="mai",
    marathi="mr",
    telugu="te",
)

app = FastAPI()


def _resolve_language_code(lan):
    """Map a lowercase language name or code to the code passed to the model.

    Returns ``None`` when *lan* is neither a language name (a key of
    ``languageCODE``) nor a language code (one of its values).
    """
    if lan in languageCODE:
        return languageCODE[lan]
    if lan in languageCODE.values():
        return lan
    return None


def _to_wav_bytes(wav, rate):
    """Peak-normalise float samples to 16-bit PCM and encode them as WAV bytes."""
    samples = np.array(wav)
    # An empty waveform has no peak; np.max would raise on it.
    peak = float(np.max(np.abs(samples))) if samples.size else 0.0
    # Scale so the loudest sample reaches full int16 range; the 0.01 floor
    # bounds the gain for near-silent clips and avoids dividing by zero.
    pcm = (samples * (32767 / max(0.01, peak))).astype(np.int16)
    buffer = BytesIO()
    write(buffer, rate, pcm)
    return buffer.getvalue()


@app.get("/Get_Inference")
async def Inference(text: str, lang: str, speaker: str):
    """Synthesise *text* in the voice of *speaker* and return it as a WAV file.

    Query parameters:
        text: Text to synthesise; must be non-empty.
        lang: Language name (a key of ``languageCODE``, e.g. ``"hindi"``) or
            language code (one of its values, e.g. ``"hi"``); case-insensitive.
        speaker: Reference-voice key from ``ref_path`` (e.g.
            ``"hindi_female"``); case-insensitive.

    Returns:
        An ``audio/wav`` response (16-bit PCM, peak-normalised) on success,
        or a 422 JSON ``{"comment": ...}`` when a field is empty or the
        speaker/language is not recognised.
    """
    if not text or not lang or not speaker:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    spk = speaker.lower()
    lan = lang.lower()

    if spk not in ref_path:
        return JSONResponse({"comment" : "Speaker not present in the system."}, status_code = 422)

    # Accept either form: a known name ("hindi") or a known code ("hi").
    lang_code = _resolve_language_code(lan)
    if lang_code is None:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    # Use the normalised key: `spk` was validated above, the raw `speaker`
    # may differ in case and would raise KeyError.
    # NOTE(review): tts.tts() is a synchronous call inside an async endpoint,
    # so it blocks the event loop for the whole synthesis. Moving it to a
    # threadpool would allow concurrent model calls — confirm the model
    # tolerates that before changing.
    wav = tts.tts(text=text, speaker_wav=ref_path[spk], language=lang_code)
    audio = _to_wav_bytes(wav, sample_rate)

    alphabet = string.ascii_uppercase + string.digits + string.ascii_lowercase
    suffix = "".join(random.choice(alphabet) for _ in range(7))
    # lan and spk are validated lowercase keys/codes, so they are safe to
    # embed in the header value.
    filename = f"{lan}_{spk}_{suffix}.wav"
    return Response(
        content=audio,
        media_type="audio/wav",
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )


def start_server(host="0.0.0.0", port=8080, log_level="debug"):
    """Serve ``app`` with uvicorn in a single worker process.

    The application object is passed to uvicorn directly instead of the
    import string ``"API_Main:app"``. When this file runs as a script it is
    the ``__main__`` module, so an import string makes uvicorn import the
    file a second time under another module name. That repeats every
    top-level side effect, including loading the TTS model. Passing the
    object also removes the requirement that the file be named
    ``API_Main.py``. uvicorn needs an import string only for ``reload=True``
    or ``workers > 1``, neither of which is used here.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
        log_level: uvicorn log level.
    """
    print('Starting Server...')

    uvicorn.run(
        app,
        host=host,
        port=port,
        workers=1,
        log_level=log_level,
        reload=False,
    )


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    start_server()