SpireLab commited on
Commit
6b1e70c
·
verified ·
1 Parent(s): a17506e

Update API_Main.py

Browse files
Files changed (1) hide show
  1. API_Main.py +23 -33
API_Main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import string
3
  import random
@@ -5,10 +7,11 @@ import uvicorn
5
  import numpy as np
6
  from io import BytesIO
7
  from TTS.api import TTS
8
- from fastapi import FastAPI
9
  from scipy.io.wavfile import write
10
  from fastapi.responses import Response, JSONResponse
11
 
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  print(f"Using device: {device}", flush = True)
@@ -25,31 +28,6 @@ tts = TTS(
25
 
26
  sample_rate = 22050
27
 
28
- ref_path = {
29
- "chhattisgarhi_male" : "reference_audio/chhattisgarhi_male.wav",
30
- "chhattisgarhi_female" : "reference_audio/chhattisgarhi_female.wav",
31
- "kannada_male" : "reference_audio/kannada_male.wav",
32
- "kannada_female" : "reference_audio/kannada_female.wav",
33
- "maithili_male" : "reference_audio/maithili_male.wav",
34
- "maithili_female" : "reference_audio/maithili_female.wav",
35
- "telugu_male" : "reference_audio/telugu_male.wav",
36
- "telugu_female" : "reference_audio/telugu_female.wav",
37
- "bengali_male" : "reference_audio/bengali_male.wav",
38
- "bengali_female" : "reference_audio/bengali_female.wav",
39
- "bhojpuri_male" : "reference_audio/bhojpuri_male.wav",
40
- "bhojpuri_female" : "reference_audio/bhojpuri_female.wav",
41
- "marathi_female" : "reference_audio/marathi_female.wav",
42
- "marathi_male" : "reference_audio/marathi_male.wav",
43
- "gujarati_male" : "reference_audio/gujarati_male.wav",
44
- "gujarati_female" : "reference_audio/gujarati_female.wav",
45
- "hindi_male" : "reference_audio/hindi_male.wav",
46
- "hindi_female" : "reference_audio/hindi_female.wav",
47
- "magahi_female" : "reference_audio/magahi_female.wav",
48
- "magahi_male" : "reference_audio/magahi_male.wav",
49
- "english_female" : "reference_audio/english_female.wav",
50
- "english_male" : "reference_audio/english_male.wav",
51
- }
52
-
53
  languageCODE = {
54
  "bhojpuri": "bho",
55
  "bengali": "bn",
@@ -70,29 +48,41 @@ def Is_alive():
70
  return {"message" : "Server is Live"}
71
 
72
  @app.get("/Get_Inference")
73
- async def Inference(text : str, lang : str, speaker : str):
74
 
75
- if not text or not lang or not speaker:
76
  return JSONResponse({"comment" : "Missing Field."}, status_code = 422)
77
 
78
- spk = speaker.lower()
79
  lan = lang.lower()
80
 
81
- if spk not in ref_path:
82
- return JSONResponse({"comment" : "Speaker not present in the system."}, status_code = 422)
83
 
84
  if lan not in languageCODE or lan not in languageCODE.values():
85
  return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)
86
 
 
 
 
 
87
 
88
- wav = np.array(tts.tts(text=text, speaker_wav=ref_path[speaker], language = languageCODE[lan] if lan not in languageCODE.values() else lan))
 
 
 
 
 
 
89
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
90
  wav_norm = wav_norm.astype(np.int16)
91
 
92
  wav_buffer = BytesIO()
93
  write(wav_buffer, sample_rate, wav_norm)
94
  wav_buffer.seek(0)
95
- wav_buffer.name = lang + "_" + speaker + "_" + ''.join(random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase) for _ in range(7)) + ".wav"
 
 
 
96
  return Response(wav_buffer.read())
97
 
98
 
 
1
+ import os
2
+ import wave
3
  import torch
4
  import string
5
  import random
 
7
  import numpy as np
8
  from io import BytesIO
9
  from TTS.api import TTS
10
+ from fastapi import FastAPI, UploadFile
11
  from scipy.io.wavfile import write
12
  from fastapi.responses import Response, JSONResponse
13
 
14
+ os.makedirs("temp/", exist_ok = True)
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(f"Using device: {device}", flush = True)
 
28
 
29
  sample_rate = 22050
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  languageCODE = {
32
  "bhojpuri": "bho",
33
  "bengali": "bn",
 
48
  return {"message" : "Server is Live"}
49
 
50
  @app.get("/Get_Inference")
51
+ async def Inference(text : str, lang : str, speaker_wav : UploadFile):
52
 
53
+ if not text or not lang or not speaker_wav:
54
  return JSONResponse({"comment" : "Missing Field."}, status_code = 422)
55
 
 
56
  lan = lang.lower()
57
 
58
+ if not speaker_wav:
59
+ return JSONResponse({"comment" : "Speaker file not provided."}, status_code = 422)
60
 
61
  if lan not in languageCODE or lan not in languageCODE.values():
62
  return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)
63
 
64
+ speaker_wav_filename = "temp/" + random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=5)) + "_" + speaker_wav.filename
65
+
66
+ with open(speaker_wav_filename , "rb") as wavFile:
67
+ wavFile.write(await speaker_wav.file.read())
68
 
69
+ try:
70
+ with wave.open(speaker_wav_filename) as temper:
71
+ pass
72
+ except:
73
+ return JSONResponse({"comment" : "Audio file format not supported."}, status_code = 422)
74
+
75
+ wav = np.array(tts.tts(text=text, speaker_wav = speaker_wav_filename, language = languageCODE[lan]))
76
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
77
  wav_norm = wav_norm.astype(np.int16)
78
 
79
  wav_buffer = BytesIO()
80
  write(wav_buffer, sample_rate, wav_norm)
81
  wav_buffer.seek(0)
82
+ wav_buffer.name = lang + "_" + speaker_wav.filename + "_" + ''.join(random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase) for _ in range(7)) + ".wav"
83
+
84
+ os.remove(speaker_wav_filename)
85
+
86
  return Response(wav_buffer.read())
87
 
88