HusseinBashir commited on
Commit
52d260b
·
verified ·
1 Parent(s): cd9dd91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -5,13 +5,14 @@ import scipy.io.wavfile
5
  from transformers import VitsModel, AutoTokenizer
6
  import re
7
 
8
- # Load fine-tuned model from Hugging Face Hub or local path
9
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
10
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model.to(device)
13
  model.eval()
14
 
 
15
  number_words = {
16
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
17
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
@@ -67,28 +68,58 @@ def number_to_words(number):
67
  return str(number)
68
 
69
  def normalize_text(text):
70
- numbers = re.findall(r'\d+', text)
71
- for num in numbers:
72
- text = text.replace(num, number_to_words(num))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  text = text.replace("KH", "qa").replace("Z", "S")
74
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
75
  text = text.replace("ZamZam", "SamSam")
76
  return text
77
 
78
  def tts(text):
79
- text = normalize_text(text)
80
- inputs = tokenizer(text, return_tensors="pt").to(device)
81
- with torch.no_grad():
82
- waveform = model(**inputs).waveform.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  filename = "output.wav"
84
- scipy.io.wavfile.write(filename, rate=model.config.sampling_rate, data=(waveform * 32767).astype(np.int16))
85
  return filename
86
 
 
87
  gr.Interface(
88
  fn=tts,
89
- inputs=gr.Textbox(label="Geli qoraal Soomaali ah"),
90
  outputs=gr.Audio(label="Codka TTS"),
91
  title="Somali TTS",
92
- description="Ku qor qoraal Soomaaliyeed si aad u maqasho cod dabiici ah.",
93
  ).launch()
94
-
 
5
  from transformers import VitsModel, AutoTokenizer
6
  import re
7
 
8
+ # Load model and tokenizer
9
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
10
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model.to(device)
13
  model.eval()
14
 
15
+ # Numbers in Somali
16
  number_words = {
17
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
18
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
 
68
  return str(number)
69
 
70
  def normalize_text(text):
71
+ # Remove commas from numbers like 1,000,000
72
+ text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
73
+ # Remove decimals (e.g., .00)
74
+ text = re.sub(r'\.\d+', '', text)
75
+ # Replace numbers with Somali words
76
+ def replace_num(match):
77
+ return number_to_words(match.group())
78
+ text = re.sub(r'\d+', replace_num, text)
79
+ # Replace special symbols
80
+ symbol_map = {
81
+ '$': 'doolar',
82
+ '=': 'egwal',
83
+ '+': 'balaas',
84
+ '#': 'haash'
85
+ }
86
+ for sym, word in symbol_map.items():
87
+ text = text.replace(sym, ' ' + word + ' ')
88
+ # Character normalization
89
  text = text.replace("KH", "qa").replace("Z", "S")
90
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
91
  text = text.replace("ZamZam", "SamSam")
92
  return text
93
 
94
  def tts(text):
95
+ paragraphs = text.strip().split("\n")
96
+ audio_list = []
97
+
98
+ for i, para in enumerate(paragraphs):
99
+ if not para.strip():
100
+ continue
101
+ norm_para = normalize_text(para)
102
+ inputs = tokenizer(norm_para, return_tensors="pt").to(device)
103
+ with torch.no_grad():
104
+ waveform = model(**inputs).waveform.squeeze().cpu().numpy()
105
+
106
+ # Add pause between paragraphs
107
+ if i < len(paragraphs) - 1:
108
+ pause = np.zeros(int(model.config.sampling_rate * 0.8)) # 0.8s pause
109
+ audio_list.append(np.concatenate((waveform, pause)))
110
+ else:
111
+ audio_list.append(waveform)
112
+
113
+ final_audio = np.concatenate(audio_list)
114
  filename = "output.wav"
115
+ scipy.io.wavfile.write(filename, rate=model.config.sampling_rate, data=(final_audio * 32767).astype(np.int16))
116
  return filename
117
 
118
+ # Gradio interface
119
  gr.Interface(
120
  fn=tts,
121
+ inputs=gr.Textbox(label="Geli qoraal Soomaali ah", lines=10, placeholder="Ku qor 1 ama in ka badan paragraph..."),
122
  outputs=gr.Audio(label="Codka TTS"),
123
  title="Somali TTS",
124
+ description="Ku qor qoraal Soomaaliyeed si aad u maqasho cod dabiici ah."
125
  ).launch()