najiib9 commited on
Commit
1b2dbba
·
verified ·
1 Parent(s): 7a4e652

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment settings
2
+ import os
3
+ os.environ["HF_HOME"] = "/tmp"
4
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
5
+ os.environ["TORCH_HOME"] = "/tmp"
6
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
7
+
8
+ import io
9
+ import re
10
+ import math
11
+ import numpy as np
12
+ import scipy.io.wavfile
13
+ import torch
14
+ from fastapi import FastAPI, Query
15
+ from fastapi.responses import StreamingResponse
16
+ from pydantic import BaseModel
17
+ from transformers import VitsModel, AutoTokenizer
18
+
19
+ app = FastAPI()
20
+
21
+ model = VitsModel.from_pretrained("najiib9/somali_tts_final_model")
22
+ tokenizer = AutoTokenizer.from_pretrained("najiib9/somali_tts_final_model")
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ model.to(device)
26
+ model.eval()
27
+
28
+ number_words = {
29
+ 0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
30
+ 6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
31
+ 11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
32
+ 14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
33
+ 17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
34
+ 20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
35
+ 60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
36
+ 100: "boqol", 1000: "kun"
37
+ }
38
+
39
+ shortcut_map = {
40
+ "asc": "asalaamu caleykum",
41
+ "wcs": "wacaleykum salaam",
42
+ "fcn": "fiican",
43
+ "xld": "xaaladda ka waran",
44
+ "kwrn": "kawaran",
45
+ "scw": "salalaahu caleyhi wa salam",
46
+ "alx": "alxamdu lilaahi",
47
+ "m.a": "maasha allah",
48
+ "sthy": "side tahey",
49
+ "sxp": "saaxiib"
50
+ }
51
+
52
+ country_map = {
53
+ "somalia": "Soomaaliya",
54
+ "ethiopia": "Itoobiya",
55
+ "kenya": "Kenya",
56
+ "djibouti": "Jabuuti",
57
+ "sudan": "Suudaan",
58
+ "Yeman": "yemaan",
59
+ "uganda": "Ugaandha",
60
+ "tanzania": "Tansaaniya",
61
+ "egypt": "Masar",
62
+ "libya": "Liibiya",
63
+ "algeria": "Aljeeriya",
64
+ "morocco": "Morooko",
65
+ "tunisia": "Tuniisiya",
66
+ "eritrea": "Eriteriya",
67
+ "malawi": "Malaawi",
68
+ "English": "ingiriis",
69
+ "Spain": "isbeen",
70
+ "Brazil": "baraasiil",
71
+ "niger": "Niyjer",
72
+ "Italy": "itaaliya",
73
+ "united states": "Maraykanka",
74
+ "china": "Shiinaha",
75
+ "india": "Hindiya",
76
+ "russia": "Ruushka",
77
+ "Saudi Arabia": "Sucuudi Carabiya",
78
+ "germany": "Jarmalka",
79
+ "france": "Faransiiska",
80
+ "japan": "Jabaan",
81
+ "canada": "Kanada",
82
+ "australia": "Australia"
83
+ }
84
+
85
+ def number_to_words(number):
86
+ number = int(number)
87
+ if number < 20:
88
+ return number_words[number]
89
+ elif number < 100:
90
+ tens, unit = divmod(number, 10)
91
+ return number_words[tens * 10] + (" iyo " + number_words[unit] if unit else "")
92
+ elif number < 1000:
93
+ hundreds, remainder = divmod(number, 100)
94
+ part = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
95
+ if remainder:
96
+ part += " iyo " + number_to_words(remainder)
97
+ return part
98
+ elif number < 1000000:
99
+ thousands, remainder = divmod(number, 1000)
100
+ words = [number_to_words(thousands) + " kun" if thousands > 1 else "kun"]
101
+ if remainder:
102
+ words.append("iyo " + number_to_words(remainder))
103
+ return " ".join(words)
104
+ elif number < 1000000000:
105
+ millions, remainder = divmod(number, 1000000)
106
+ words = [number_to_words(millions) + " milyan" if millions > 1 else "milyan"]
107
+ if remainder:
108
+ words.append(number_to_words(remainder))
109
+ return " ".join(words)
110
+ else:
111
+ return str(number)
112
+
113
+ def normalize_text(text):
114
+ text = re.sub(r'(?i)(?<!\w)zamzam(?!\w)', 'samsam', text)
115
+
116
+ def replace_shortcuts(match):
117
+ word = match.group(0).lower()
118
+ return shortcut_map.get(word, word)
119
+ pattern = re.compile(r'\b(' + '|'.join(re.escape(k) for k in shortcut_map.keys()) + r')\b', re.IGNORECASE)
120
+ text = pattern.sub(replace_shortcuts, text)
121
+
122
+ def replace_countries(match):
123
+ word = match.group(0).lower()
124
+ return country_map.get(word, word)
125
+ country_pattern = re.compile(r'\b(' + '|'.join(re.escape(k) for k in country_map.keys()) + r')\b', re.IGNORECASE)
126
+ text = country_pattern.sub(replace_countries, text)
127
+
128
+ text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
129
+ text = re.sub(r'\.\d+', '', text)
130
+
131
+ def replace_num(match):
132
+ return number_to_words(match.group())
133
+ text = re.sub(r'\d+', replace_num, text)
134
+
135
+ symbol_map = {
136
+ '$': 'doolar',
137
+ '=': 'egwal',
138
+ '+': 'balaas',
139
+ '#': 'haash'
140
+ }
141
+ for sym, word in symbol_map.items():
142
+ text = text.replace(sym, ' ' + word + ' ')
143
+
144
+ text = text.replace("KH", "qa").replace("Z", "S")
145
+ text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
146
+
147
+ if re.search(r'(?i)(zamzam|samsam)[\s\.,!?]*$', text.strip()):
148
+ text += " m"
149
+
150
+ return text
151
+
152
+ def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
153
+ np_waveform = waveform.cpu().numpy()
154
+ if np_waveform.ndim == 3:
155
+ np_waveform = np_waveform[0]
156
+ if np_waveform.ndim == 2:
157
+ np_waveform = np_waveform.mean(axis=0)
158
+ np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
159
+ pcm_waveform = (np_waveform * 32767).astype(np.int16)
160
+ buf = io.BytesIO()
161
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
162
+ buf.seek(0)
163
+ return buf.read()
164
+
165
+ class TextIn(BaseModel):
166
+ inputs: str
167
+
168
+ @app.post("/synthesize")
169
+ async def synthesize_post(data: TextIn):
170
+ paragraphs = [p.strip() for p in data.inputs.split('\n') if p.strip()]
171
+ sample_rate = getattr(model.config, "sampling_rate", 22050)
172
+ all_waveforms = []
173
+
174
+ for paragraph in paragraphs:
175
+ normalized = normalize_text(paragraph)
176
+ inputs = tokenizer(normalized, return_tensors="pt").to(device)
177
+ with torch.no_grad():
178
+ output = model(**inputs)
179
+ waveform = (
180
+ output.waveform if hasattr(output, "waveform") else
181
+ output["waveform"] if isinstance(output, dict) and "waveform" in output else
182
+ output[0] if isinstance(output, (tuple, list)) else
183
+ None
184
+ )
185
+ if waveform is None:
186
+ continue
187
+ all_waveforms.append(waveform)
188
+ silence = torch.zeros(1, sample_rate).to(waveform.device)
189
+ all_waveforms.append(silence)
190
+
191
+ if not all_waveforms:
192
+ return {"error": "No audio generated."}
193
+
194
+ final_waveform = torch.cat(all_waveforms, dim=-1)
195
+ wav_bytes = waveform_to_wav_bytes(final_waveform, sample_rate=sample_rate)
196
+ return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
197
+
198
+ @app.get("/synthesize")
199
+ async def synthesize_get(text: str = Query(..., description="Text to synthesize"), test: bool = Query(False)):
200
+ if test:
201
+ paragraphs = text.count("\n") + 1
202
+ duration_s = paragraphs * 6
203
+ sample_rate = 22050
204
+ t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
205
+ freq = 440
206
+ waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
207
+ pcm_waveform = (waveform * 32767).astype(np.int16)
208
+ buf = io.BytesIO()
209
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
210
+ buf.seek(0)
211
+ return StreamingResponse(buf, media_type="audio/wav")
212
+
213
+ normalized = normalize_text(text)
214
+ inputs = tokenizer(normalized, return_tensors="pt").to(device)
215
+ with torch.no_grad():
216
+ output = model(**inputs)
217
+ waveform = (
218
+ output.waveform if hasattr(output, "waveform") else
219
+ output["waveform"] if isinstance(output, dict) and "waveform" in output else
220
+ output[0] if isinstance(output, (tuple, list)) else
221
+ None
222
+ )
223
+ if waveform is None:
224
+ return {"error": "Waveform not found in model output"}
225
+ sample_rate = getattr(model.config, "sampling_rate", 22050)
226
+ wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
227
+ return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")