Somalitts committed on
Commit
2860b2a
·
verified ·
1 Parent(s): 767e58a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -35
app.py CHANGED
@@ -10,71 +10,177 @@ from speechbrain.pretrained import EncoderClassifier
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
14
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
15
- model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
 
16
 
17
- # Speaker encoder
18
  speaker_model = EncoderClassifier.from_hparams(
19
  source="speechbrain/spkrec-xvect-voxceleb",
20
  run_opts={"device": device},
21
- savedir="/tmp/spk_model"
22
  )
23
 
24
- # Load female embedding only
25
- def get_embedding(wav_path, pt_path):
26
- if os.path.exists(pt_path):
27
- return torch.load(pt_path).to(device)
28
- audio, sr = torchaudio.load(wav_path)
 
29
  audio = torchaudio.functional.resample(audio, sr, 16000).mean(dim=0).unsqueeze(0).to(device)
30
  with torch.no_grad():
31
  emb = speaker_model.encode_batch(audio)
32
  emb = torch.nn.functional.normalize(emb, dim=2).squeeze()
33
- torch.save(emb.cpu(), pt_path)
34
- return emb
35
 
36
- embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
37
-
38
- # Text normalization
39
  number_words = {
40
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
41
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
 
 
 
42
  20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
43
  60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
44
  100: "boqol", 1000: "kun"
45
  }
46
 
47
- def number_to_words(n):
48
- if n < 20:
49
- return number_words.get(n, str(n))
50
- elif n < 100:
51
- tens, unit = divmod(n, 10)
52
- return number_words[tens * 10] + (" " + number_words[unit] if unit else "")
53
- elif n < 1000:
54
- hundreds, rem = divmod(n, 100)
55
- return (number_words[hundreds] + " boqol" if hundreds > 1 else "boqol") + (" " + number_to_words(rem) if rem else "")
56
- elif n < 1_000_000:
57
- th, rem = divmod(n, 1000)
58
- return (number_to_words(th) + " kun") + (" " + number_to_words(rem) if rem else "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  else:
60
- return str(n)
61
 
62
  def replace_numbers_with_words(text):
63
- return re.sub(r'\b\d+\b', lambda m: number_to_words(int(m.group())), text)
 
 
 
64
 
65
  def normalize_text(text):
66
  text = text.lower()
67
  text = replace_numbers_with_words(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  text = re.sub(r'[^\w\s]', '', text)
 
69
  return text
70
 
71
- # Gradio interface
 
 
 
 
 
 
72
  iface = gr.Interface(
73
- fn=tts,
74
- inputs=gr.Textbox(label="Geli qoraalka af-soomaali", lines=10, placeholder="Ku qor qoraalka..."),
75
- outputs=gr.Audio(label="Codka la abuuray", type="filepath"),
76
- title="Somali TTS - Qaybo Dheer & Cod Gaar ah",
77
- description="Qoraal dheer ayaad gali kartaa oo lagu kala jarayo paragraphs. Waxaa lagu abuurayaa cod TTS af Soomaali ah."
78
  )
79
 
80
  iface.launch()
 
10
 
# Run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models: text processor, fine-tuned SpeechT5 checkpoint, HiFi-GAN vocoder.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Speaker encoder used to turn the reference recording into an embedding.
speaker_model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    run_opts={"device": device},
    savedir="./spk_model",
)

# Speaker embedding: computed once from "1.wav", then cached on disk so later
# start-ups only have to load the saved tensor.
EMB_PATH = "speaker_embedding.pt"
if not os.path.exists(EMB_PATH):
    audio, sr = torchaudio.load("1.wav")
    # Resample to 16000 Hz, average over dim 0 (presumably the channel axis),
    # and restore a leading batch dimension for the encoder.
    audio = torchaudio.functional.resample(audio, sr, 16000)
    audio = audio.mean(dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        emb = speaker_model.encode_batch(audio)
        emb = torch.nn.functional.normalize(emb, dim=2).squeeze()
    # The cache always holds a CPU tensor; it is moved to `device` on load.
    torch.save(emb.cpu(), EMB_PATH)
    speaker_embedding = emb
else:
    speaker_embedding = torch.load(EMB_PATH).to(device)
36
 
# Number conversion (Somali)
# Building blocks for spelling out integers: units, the teens (written as
# "toban iyo <unit>"), round tens, and the scale words for 100 and 1000.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}

# Chat-style abbreviations and the full phrases to speak in their place.
# Keys are lower-case because the text is lower-cased before lookup.
shortcut_map = {
    "asc": "asalaamu caleykum",
    "wcs": "wacaleykum salaam",
    "fcn": "fiican",
    "xld": "xaaladda ka waran",
    "kwrn": "kawaran",
    "scw": "salalaahu caleyhi wa salam",
    "alx": "alxamdu lilaahi",
    "m.a": "maasha allah",
    "sthy": "side tahey",
    "sxp": "saaxiib",
}

# English country/language names and their Somali spellings.
# Every key MUST be lower-case: normalize_text lower-cases each match before
# looking it up here, so a key containing capitals can never be found.
country_map = {
    "somalia": "Soomaaliya",
    "ethiopia": "Itoobiya",
    "kenya": "Kenya",
    "djibouti": "Jabuuti",
    "sudan": "Suudaan",
    "yemen": "yemaan",
    "yeman": "yemaan",  # original (misspelled) key, kept so it still matches
    "uganda": "Ugaandha",
    "tanzania": "Tansaaniya",
    "egypt": "Masar",
    "libya": "Liibiya",
    "algeria": "Aljeeriya",
    "morocco": "Morooko",
    "tunisia": "Tuniisiya",
    "eritrea": "Eriteriya",
    "malawi": "Malaawi",
    "english": "ingiriis",
    "spain": "isbeen",
    "brazil": "baraasiil",
    "niger": "Niyjer",
    "italy": "itaaliya",
    "united states": "Maraykanka",
    "china": "Shiinaha",
    "india": "Hindiya",
    "russia": "Ruushka",
    "saudi arabia": "Sucuudi Carabiya",
    "germany": "Jarmalka",
    "france": "Faransiiska",
    "japan": "Jabaan",
    "canada": "Kanada",
    "australia": "Australia",
}
94
+
def number_to_words(number):
    """Spell out a non-negative integer using the words in ``number_words``.

    Args:
        number: An int, or anything ``int()`` accepts (e.g. a digit string).

    Returns:
        The spelled-out number for values below 1,000,000,000. Larger values
        are returned unchanged as a decimal string.

    Raises:
        KeyError: If ``number`` is negative (no entry in ``number_words``).
    """
    number = int(number)

    if number < 20:
        return number_words[number]

    if number < 100:
        tens, unit = divmod(number, 10)
        words = number_words[tens * 10]
        if unit:
            words += " iyo " + number_words[unit]
        return words

    if number < 1000:
        hundreds, remainder = divmod(number, 100)
        # Exactly one hundred is the bare scale word, without a leading "koow".
        words = number_words[hundreds] + " boqol" if hundreds > 1 else "boqol"
        if remainder:
            words += " iyo " + number_to_words(remainder)
        return words

    # Thousands and millions share one rule: "<count> <scale>", with the count
    # omitted when it is 1, and any remainder joined with "iyo". The millions
    # branch previously dropped the "iyo", unlike every other branch.
    for limit, scale, scale_word in (
        (1000000, 1000, "kun"),
        (1000000000, 1000000, "milyan"),
    ):
        if number < limit:
            count, remainder = divmod(number, scale)
            if count == 1:
                words = scale_word
            else:
                words = number_to_words(count) + " " + scale_word
            if remainder:
                words += " iyo " + number_to_words(remainder)
            return words

    # Out of range for the word table: fall back to the digits.
    return str(number)
130
 
def replace_numbers_with_words(text):
    """Return ``text`` with every standalone run of digits spelled out.

    A run counts only when it sits between word boundaries, so digits glued
    to letters (e.g. "a1") are left untouched.
    """
    return re.sub(
        r'\b\d+\b',
        lambda digits: number_to_words(int(digits.group())),
        text,
    )
136
 
def normalize_text(text):
    """Prepare raw user text for the TTS model.

    Steps, in order: lower-case; tidy numeric notation (remove thousands
    separators, drop fractional parts); spell out numbers; expand chat
    shortcuts; replace country names; spell out a few symbols; strip all
    remaining punctuation.

    Args:
        text: The raw input string.

    Returns:
        The normalized string.
    """
    text = text.lower()

    # Numeric clean-up has to run BEFORE digits are spelled out. Done
    # afterwards it finds no digits left, and "1,000" is read as the two
    # numbers "1" and "000".
    text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
    # Only drop ".digits" when it follows a digit (a fractional part), so a
    # number that merely comes right after a full stop is not deleted.
    text = re.sub(r'(?<=\d)\.\d+', '', text)

    text = replace_numbers_with_words(text)

    def replace_shortcuts(match):
        word = match.group(0).lower()
        return shortcut_map.get(word, word)

    pattern = re.compile(
        r'\b(' + '|'.join(re.escape(k) for k in shortcut_map) + r')\b',
        re.IGNORECASE,
    )
    text = pattern.sub(replace_shortcuts, text)

    # Matches are lower-cased before lookup, so index the table by lower-cased
    # key; otherwise entries whose keys contain capitals can never be found.
    country_lookup = {name.lower(): spoken for name, spoken in country_map.items()}

    def replace_countries(match):
        word = match.group(0).lower()
        return country_lookup.get(word, word)

    country_pattern = re.compile(
        r'\b(' + '|'.join(re.escape(k) for k in country_lookup) + r')\b',
        re.IGNORECASE,
    )
    text = country_pattern.sub(replace_countries, text)

    # Spell out symbols that carry meaning before punctuation is removed.
    symbol_map = {
        '$': 'doolar',
        '=': 'egwal',
        '+': 'balaas',
        '#': 'haash',
    }
    for sym, word in symbol_map.items():
        text = text.replace(sym, ' ' + word + ' ')

    # Remove everything that is neither a word character nor whitespace.
    text = re.sub(r'[^\w\s]', '', text)

    return text
170
 
def text_to_speech(text):
    """Synthesize speech for ``text`` with the module-level models.

    The text is normalized, tokenized by ``processor``, and passed to
    ``model.generate_speech`` together with the cached speaker embedding and
    the vocoder.

    Returns:
        A ``(16000, samples)`` tuple, where ``samples`` is the generated
        audio as a NumPy array on the CPU.
    """
    cleaned = normalize_text(text)
    encoded = processor(text=cleaned, return_tensors="pt").to(device)
    with torch.no_grad():
        waveform = model.generate_speech(
            encoded["input_ids"],
            speaker_embedding.unsqueeze(0),  # add a leading batch dimension
            vocoder=vocoder,
        )
    return 16000, waveform.cpu().numpy()
177
+
# Gradio UI: one text box in, one audio player out, wired to text_to_speech.
_text_input = gr.Textbox(label="Geli qoraalka af-soomaali")
_audio_output = gr.Audio(label="Codka la abuuray", type="numpy")

iface = gr.Interface(
    fn=text_to_speech,
    inputs=_text_input,
    outputs=_audio_output,
    title="Somali TTS",
    description="TTS Soomaaliyeed oo la adeegsaday cod gaar ah (1.wav)",
)

iface.launch()